From 9899e6b9992b97de9d40df1dcc38b82d9165ec3f Mon Sep 17 00:00:00 2001
From: pdmurray
Date: Wed, 2 Oct 2024 12:26:30 -0700
Subject: [PATCH 01/17] Apply fixes for linting issues

Consolidate repeated single-symbol imports into parenthesized
multi-imports, remove trailing blank lines and add missing final
newlines in the generated g3doc API docs, and add a .gitignore plus
an empty pyproject.toml.
---
 .bazelversion | 2 +-
 .gitignore | 2 +
 LICENSE | 2 +-
 g3doc/api_docs/python/expression_impl.md | 1 -
 .../python/expression_impl/all_symbols.md | 2 +-
 .../python/expression_impl/apply_schema.md | 1 -
 .../python/expression_impl/broadcast.md | 1 -
 .../python/expression_impl/depth_limit.md | 1 -
 .../expression_impl/filter_expression.md | 1 -
 .../filter_expression/filter_by_child.md | 1 -
 .../filter_expression/filter_by_sibling.md | 1 -
 .../api_docs/python/expression_impl/index.md | 1 -
 .../index/get_index_from_end.md | 1 -
 .../index/get_positional_index.md | 1 -
 .../python/expression_impl/map_prensor.md | 1 -
 .../map_prensor/map_ragged_tensor.md | 1 -
 .../map_prensor/map_sparse_tensor.md | 1 -
 .../expression_impl/map_prensor_to_prensor.md | 1 -
 .../map_prensor_to_prensor/Schema.md | 6 --
 .../map_prensor_to_prensor/create_schema.md | 1 -
 .../map_prensor_to_prensor.md | 1 -
 .../python/expression_impl/map_values.md | 1 -
 .../map_values/map_many_values.md | 1 -
 .../expression_impl/map_values/map_values.md | 1 -
 .../map_values/map_values_anonymous.md | 1 -
 .../python/expression_impl/parquet.md | 1 -
 .../expression_impl/parquet/ParquetDataset.md | 6 --
 .../parquet/calculate_parquet_values.md | 1 -
 .../create_expression_from_parquet_file.md | 1 -
 .../python/expression_impl/placeholder.md | 1 -
 .../create_expression_from_schema.md | 1 -
 .../get_placeholder_paths_from_graph.md | 1 -
 .../python/expression_impl/project.md | 1 -
 .../python/expression_impl/project/project.md | 1 -
 .../python/expression_impl/promote.md | 1 -
 .../promote/PromoteChildExpression.md | 5 --
 .../promote/PromoteExpression.md | 5 --
 .../expression_impl/promote_and_broadcast.md | 1 -
 .../promote_and_broadcast.md | 1 -
 .../api_docs/python/expression_impl/proto.md | 1 -
 .../expression_impl/proto/DescriptorPool.md | 5 --
 .../proto/FileDescriptorSet.md | 3 -
 ...ate_expression_from_file_descriptor_set.md | 1 -
 .../proto/create_expression_from_proto.md | 1 -
 .../proto/create_transformed_field.md | 1 -
 .../api_docs/python/expression_impl/reroot.md | 1 -
 .../python/expression_impl/reroot/reroot.md | 1 -
 g3doc/api_docs/python/expression_impl/size.md | 1 -
 .../expression_impl/size/SizeExpression.md | 5 --
 .../python/expression_impl/size/has.md | 1 -
 .../python/expression_impl/size/size.md | 1 -
 .../expression_impl/size/size_anonymous.md | 1 -
 .../expression_impl/slice_expression.md | 1 -
 .../slice_expression/slice_expression.md | 1 -
 g3doc/api_docs/python/s2t.md | 1 -
 g3doc/api_docs/python/s2t/ChildNodeTensor.md | 5 --
 g3doc/api_docs/python/s2t/Expression.md | 5 --
 g3doc/api_docs/python/s2t/LeafNodeTensor.md | 5 --
 g3doc/api_docs/python/s2t/Path.md | 4 --
 g3doc/api_docs/python/s2t/Prensor.md | 5 --
 g3doc/api_docs/python/s2t/RootNodeTensor.md | 5 --
 g3doc/api_docs/python/s2t/all_symbols.md | 2 +-
 .../api_docs/python/s2t/calculate_prensors.md | 1 -
 .../s2t/calculate_prensors_with_graph.md | 1 -
 ...ate_expression_from_file_descriptor_set.md | 1 -
 .../s2t/create_expression_from_prensor.md | 1 -
 .../s2t/create_expression_from_proto.md | 1 -
 g3doc/api_docs/python/s2t/create_path.md | 1 -
 .../create_prensor_from_descendant_nodes.md | 1 -
 .../api_docs/python/s2t/get_ragged_tensor.md | 1 -
 .../api_docs/python/s2t/get_ragged_tensors.md | 1 -
 .../api_docs/python/s2t/get_sparse_tensor.md | 1 -
 .../api_docs/python/s2t/get_sparse_tensors.md | 1 -
 pyproject.toml | 0
 setup.py | 3 +-
 struct2tensor/__init__.py | 58 +++++++++-------
 .../__pycache__/__init__.cpython-311.pyc | Bin 0 -> 1948 bytes
 .../__pycache__/calculate.cpython-311.pyc | Bin 0 -> 27468 bytes
 .../calculate_options.cpython-311.pyc | Bin 0 -> 2965 bytes
 ...alculate_with_source_paths.cpython-311.pyc | Bin 0 -> 6595 bytes
 .../__pycache__/expression.cpython-311.pyc | Bin 0 -> 33258 bytes
 .../expression_add.cpython-311.pyc | Bin 0 -> 10636 bytes
 .../__pycache__/path.cpython-311.pyc | Bin 0 -> 19963 bytes
 .../__pycache__/prensor.cpython-311.pyc | Bin 0 -> 45082 bytes
 struct2tensor/benchmarks/ops_benchmark.py | 4 +-
 .../benchmarks/struct2tensor_benchmark.py | 64 +++++++++---------
 .../struct2tensor_benchmark_util.py | 10 +--
 struct2tensor/calculate.py | 21 ++----
 struct2tensor/calculate_options.py | 2 +-
 struct2tensor/calculate_test.py | 28 ++++----
 struct2tensor/calculate_with_source_paths.py | 22 +++---
 .../calculate_with_source_paths_test.py | 24 +++----
 struct2tensor/create_expression.py | 8 +--
 struct2tensor/create_expression_test.py | 9 ++-
 struct2tensor/expression.py | 29 ++++----
 struct2tensor/expression_add.py | 15 ++--
 struct2tensor/expression_add_test.py | 9 +--
 struct2tensor/expression_impl/__init__.py | 36 +++++-----
 .../__pycache__/__init__.cpython-311.pyc | Bin 0 -> 1314 bytes
 .../__pycache__/apply_schema.cpython-311.pyc | Bin 0 -> 9657 bytes
 .../__pycache__/broadcast.cpython-311.pyc | Bin 0 -> 14908 bytes
 .../__pycache__/depth_limit.cpython-311.pyc | Bin 0 -> 4244 bytes
 .../filter_expression.cpython-311.pyc | Bin 0 -> 20909 bytes
 .../__pycache__/index.cpython-311.pyc | Bin 0 -> 10422 bytes
 .../__pycache__/map_prensor.cpython-311.pyc | Bin 0 -> 18599 bytes
 .../map_prensor_to_prensor.cpython-311.pyc | Bin 0 -> 22795 bytes
 .../__pycache__/map_values.cpython-311.pyc | Bin 0 -> 8911 bytes
 .../__pycache__/parquet.cpython-311.pyc | Bin 0 -> 28931 bytes
 .../__pycache__/placeholder.cpython-311.pyc | Bin 0 -> 12019 bytes
 .../__pycache__/project.cpython-311.pyc | Bin 0 -> 6338 bytes
 .../__pycache__/size.cpython-311.pyc | Bin 0 -> 8041 bytes
 struct2tensor/expression_impl/apply_schema.py | 11 +--
 .../expression_impl/apply_schema_test.py | 7 +-
 struct2tensor/expression_impl/broadcast.py | 9 +--
 .../expression_impl/broadcast_test.py | 14 ++--
 struct2tensor/expression_impl/depth_limit.py | 5 +-
 .../expression_impl/depth_limit_test.py | 3 +-
 .../expression_impl/filter_expression.py | 10 +--
 .../expression_impl/filter_expression_test.py | 28 ++++----
 struct2tensor/expression_impl/index.py | 12 ++--
 struct2tensor/expression_impl/index_test.py | 23 +++----
 struct2tensor/expression_impl/map_prensor.py | 16 ++---
 .../expression_impl/map_prensor_test.py | 17 ++---
 .../expression_impl/map_prensor_to_prensor.py | 20 ++----
 .../map_prensor_to_prensor_test.py | 21 +++---
 struct2tensor/expression_impl/map_values.py | 9 +--
 .../expression_impl/map_values_test.py | 16 ++---
 struct2tensor/expression_impl/parquet.py | 12 +---
 struct2tensor/expression_impl/parquet_test.py | 14 ++--
 .../expression_impl/parse_message_level_ex.py | 8 +--
 .../parse_message_level_ex_test.py | 15 ++--
 struct2tensor/expression_impl/placeholder.py | 9 +--
 .../expression_impl/placeholder_test.py | 16 ++---
 struct2tensor/expression_impl/project.py | 10 +--
 struct2tensor/expression_impl/project_test.py | 4 +-
 struct2tensor/expression_impl/promote.py | 10 +--
 .../expression_impl/promote_and_broadcast.py | 8 +--
 .../promote_and_broadcast_test.py | 9 ++-
 struct2tensor/expression_impl/promote_test.py | 14 ++--
 struct2tensor/expression_impl/proto.py | 49 +++++++++-------
 struct2tensor/expression_impl/proto_test.py | 28 ++++----
 .../expression_impl/proto_test_util.py | 4 +-
 .../raw_parquet_dataset_test.py | 7 +-
 struct2tensor/expression_impl/reroot.py | 14 ++--
 struct2tensor/expression_impl/reroot_test.py | 18 ++---
 struct2tensor/expression_impl/size.py | 11 ++-
 struct2tensor/expression_impl/size_test.py | 14 ++--
 .../expression_impl/slice_expression.py | 12 ++--
 .../expression_impl/slice_expression_test.py | 17 +++--
 struct2tensor/expression_test.py | 16 ++---
 .../ops/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 185 bytes
 .../file_descriptor_set.cpython-311.pyc | Bin 0 -> 6135 bytes
 ...ptor_set_test.cpython-311-pytest-8.3.3.pyc | Bin 0 -> 2030 bytes
 .../gen_decode_proto_map_op.cpython-311.pyc | Bin 0 -> 695 bytes
 .../gen_parquet_dataset.cpython-311.pyc | Bin 0 -> 640 bytes
 .../struct2tensor_ops.cpython-311.pyc | Bin 0 -> 14312 bytes
 ...nsor_ops_test.cpython-311-pytest-8.3.3.pyc | Bin 0 -> 49897 bytes
 struct2tensor/ops/file_descriptor_set.py | 8 ++-
 struct2tensor/ops/file_descriptor_set_test.py | 8 +--
 struct2tensor/ops/struct2tensor_ops.py | 32 +++++----
 struct2tensor/ops/struct2tensor_ops_test.py | 48 ++++++-------
 struct2tensor/path.py | 5 +-
 struct2tensor/path_test.py | 17 ++---
 struct2tensor/prensor.py | 40 ++++++-----
 struct2tensor/prensor_test.py | 10 +--
 struct2tensor/prensor_to_structured_tensor.py | 13 ++--
 .../prensor_to_structured_tensor_test.py | 40 ++++++-----
 struct2tensor/prensor_util.py | 8 +--
 struct2tensor/prensor_value.py | 41 ++++++-----
 struct2tensor/prensor_value_test.py | 10 +--
 struct2tensor/struct2tensor_module_test.py | 2 +-
 struct2tensor/structured_tensor_to_prensor.py | 15 ++--
 .../structured_tensor_to_prensor_test.py | 16 +++--
 struct2tensor/test/expression_test_util.py | 13 ++--
 struct2tensor/test/prensor_test_util.py | 5 +-
 struct2tensor/tools/build_docs.py | 10 ++-
 .../tools/docker_build/build_manylinux.sh | 1 -
 177 files changed, 571 insertions(+), 782 deletions(-)
 create mode 100644 .gitignore
 create mode 100644 pyproject.toml
 create mode 100644 struct2tensor/__pycache__/__init__.cpython-311.pyc
 create mode 100644 struct2tensor/__pycache__/calculate.cpython-311.pyc
 create mode 100644 struct2tensor/__pycache__/calculate_options.cpython-311.pyc
 create mode 100644 struct2tensor/__pycache__/calculate_with_source_paths.cpython-311.pyc
 create mode 100644 struct2tensor/__pycache__/expression.cpython-311.pyc
 create mode 100644 struct2tensor/__pycache__/expression_add.cpython-311.pyc
 create mode 100644 struct2tensor/__pycache__/path.cpython-311.pyc
 create mode 100644 struct2tensor/__pycache__/prensor.cpython-311.pyc
 create mode 100644 struct2tensor/expression_impl/__pycache__/__init__.cpython-311.pyc
 create mode 100644 struct2tensor/expression_impl/__pycache__/apply_schema.cpython-311.pyc
 create mode 100644 struct2tensor/expression_impl/__pycache__/broadcast.cpython-311.pyc
 create mode 100644 struct2tensor/expression_impl/__pycache__/depth_limit.cpython-311.pyc
 create mode 100644 struct2tensor/expression_impl/__pycache__/filter_expression.cpython-311.pyc
 create mode 100644 struct2tensor/expression_impl/__pycache__/index.cpython-311.pyc
 create mode 100644 struct2tensor/expression_impl/__pycache__/map_prensor.cpython-311.pyc
 create mode 100644 struct2tensor/expression_impl/__pycache__/map_prensor_to_prensor.cpython-311.pyc
 create mode 100644 struct2tensor/expression_impl/__pycache__/map_values.cpython-311.pyc
 create mode 100644 struct2tensor/expression_impl/__pycache__/parquet.cpython-311.pyc
 create mode 100644 struct2tensor/expression_impl/__pycache__/placeholder.cpython-311.pyc
 create mode 100644 struct2tensor/expression_impl/__pycache__/project.cpython-311.pyc
 create mode 100644 struct2tensor/expression_impl/__pycache__/size.cpython-311.pyc
 create mode 100644 struct2tensor/ops/__pycache__/__init__.cpython-311.pyc
 create mode 100644 struct2tensor/ops/__pycache__/file_descriptor_set.cpython-311.pyc
 create mode 100644 struct2tensor/ops/__pycache__/file_descriptor_set_test.cpython-311-pytest-8.3.3.pyc
 create mode 100644 struct2tensor/ops/__pycache__/gen_decode_proto_map_op.cpython-311.pyc
 create mode 100644 struct2tensor/ops/__pycache__/gen_parquet_dataset.cpython-311.pyc
 create mode 100644 struct2tensor/ops/__pycache__/struct2tensor_ops.cpython-311.pyc
 create mode 100644 struct2tensor/ops/__pycache__/struct2tensor_ops_test.cpython-311-pytest-8.3.3.pyc

diff --git a/.bazelversion b/.bazelversion
index 4be2c72..f22d756 100644
--- a/.bazelversion
+++ b/.bazelversion
@@ -1 +1 @@
-6.5.0
\ No newline at end of file
+6.5.0
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..4549f5e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+dist/
+bazel-*
diff --git a/LICENSE b/LICENSE
index f49a4e1..261eeb9 100644
--- a/LICENSE
+++ b/LICENSE
@@ -198,4 +198,4 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
-   limitations under the License.
\ No newline at end of file
+   limitations under the License.
diff --git a/g3doc/api_docs/python/expression_impl.md b/g3doc/api_docs/python/expression_impl.md
index 40da37a..dea6a98 100644
--- a/g3doc/api_docs/python/expression_impl.md
+++ b/g3doc/api_docs/python/expression_impl.md
@@ -67,4 +67,3 @@ s2t.expression_impl.apply_schema
 [`size`](./expression_impl/size.md) module: Functions for creating new size or has expression.
 [`slice_expression`](./expression_impl/slice_expression.md) module: Implementation of slice.
-
diff --git a/g3doc/api_docs/python/expression_impl/all_symbols.md b/g3doc/api_docs/python/expression_impl/all_symbols.md
index efd3051..5556d83 100644
--- a/g3doc/api_docs/python/expression_impl/all_symbols.md
+++ b/g3doc/api_docs/python/expression_impl/all_symbols.md
@@ -64,4 +64,4 @@
 * expression_impl.size.size_anonymous
 * expression_impl.slice_expression
 * expression_impl.slice_expression.IndexValue
-* expression_impl.slice_expression.slice_expression
\ No newline at end of file
+* expression_impl.slice_expression.slice_expression
diff --git a/g3doc/api_docs/python/expression_impl/apply_schema.md b/g3doc/api_docs/python/expression_impl/apply_schema.md
index 3ae89cd..354ceea 100644
--- a/g3doc/api_docs/python/expression_impl/apply_schema.md
+++ b/g3doc/api_docs/python/expression_impl/apply_schema.md
@@ -57,4 +57,3 @@ my_new_schema has semantically identical information on the fields as
 my_schema.
 ## Functions
 [`apply_schema(...)`](../expression_impl/apply_schema/apply_schema.md)
-
diff --git a/g3doc/api_docs/python/expression_impl/broadcast.md b/g3doc/api_docs/python/expression_impl/broadcast.md
index 5e7565c..43b7351 100644
--- a/g3doc/api_docs/python/expression_impl/broadcast.md
+++ b/g3doc/api_docs/python/expression_impl/broadcast.md
@@ -95,4 +95,3 @@ session: {
 [`broadcast(...)`](../expression_impl/broadcast/broadcast.md)
 [`broadcast_anonymous(...)`](../expression_impl/broadcast/broadcast_anonymous.md)
-
diff --git a/g3doc/api_docs/python/expression_impl/depth_limit.md b/g3doc/api_docs/python/expression_impl/depth_limit.md
index 7ba27d2..22de1ba 100644
--- a/g3doc/api_docs/python/expression_impl/depth_limit.md
+++ b/g3doc/api_docs/python/expression_impl/depth_limit.md
@@ -49,4 +49,3 @@ You get:
 ## Functions
 [`limit_depth(...)`](../expression_impl/depth_limit/limit_depth.md): Limit the depth to nodes k steps from expr.
-
diff --git a/g3doc/api_docs/python/expression_impl/filter_expression.md b/g3doc/api_docs/python/expression_impl/filter_expression.md
index d9018a6..d16af36 100644
--- a/g3doc/api_docs/python/expression_impl/filter_expression.md
+++ b/g3doc/api_docs/python/expression_impl/filter_expression.md
@@ -73,4 +73,3 @@ root_2 = filter_expression.filter_by_child(
 [`filter_by_child(...)`](../expression_impl/filter_expression/filter_by_child.md): Filter an expression by an optional boolean child field.
 [`filter_by_sibling(...)`](../expression_impl/filter_expression/filter_by_sibling.md): Filter an expression by its sibling.
-
diff --git a/g3doc/api_docs/python/expression_impl/filter_expression/filter_by_child.md b/g3doc/api_docs/python/expression_impl/filter_expression/filter_by_child.md
index 5427623..1b3dc9c 100644
--- a/g3doc/api_docs/python/expression_impl/filter_expression/filter_by_child.md
+++ b/g3doc/api_docs/python/expression_impl/filter_expression/filter_by_child.md
@@ -87,4 +87,3 @@ The new root expression.
-
diff --git a/g3doc/api_docs/python/expression_impl/filter_expression/filter_by_sibling.md b/g3doc/api_docs/python/expression_impl/filter_expression/filter_by_sibling.md
index 1e3808d..8303c90 100644
--- a/g3doc/api_docs/python/expression_impl/filter_expression/filter_by_sibling.md
+++ b/g3doc/api_docs/python/expression_impl/filter_expression/filter_by_sibling.md
@@ -89,4 +89,3 @@ a new root.
-
diff --git a/g3doc/api_docs/python/expression_impl/index.md b/g3doc/api_docs/python/expression_impl/index.md
index 822383a..67ba68e 100644
--- a/g3doc/api_docs/python/expression_impl/index.md
+++ b/g3doc/api_docs/python/expression_impl/index.md
@@ -133,4 +133,3 @@ take little memory or CPU.
 [`get_index_from_end(...)`](../expression_impl/index/get_index_from_end.md): Gets the number of steps from the end of the array.
 [`get_positional_index(...)`](../expression_impl/index/get_positional_index.md): Gets the positional index.
-
diff --git a/g3doc/api_docs/python/expression_impl/index/get_index_from_end.md b/g3doc/api_docs/python/expression_impl/index/get_index_from_end.md
index 6ab17db..a70ae01 100644
--- a/g3doc/api_docs/python/expression_impl/index/get_index_from_end.md
+++ b/g3doc/api_docs/python/expression_impl/index/get_index_from_end.md
@@ -79,4 +79,3 @@ The new expression and the new path as a pair.
-
diff --git a/g3doc/api_docs/python/expression_impl/index/get_positional_index.md b/g3doc/api_docs/python/expression_impl/index/get_positional_index.md
index 43b4e2b..95d8486 100644
--- a/g3doc/api_docs/python/expression_impl/index/get_positional_index.md
+++ b/g3doc/api_docs/python/expression_impl/index/get_positional_index.md
@@ -79,4 +79,3 @@ The new expression and the new path as a pair.
-
diff --git a/g3doc/api_docs/python/expression_impl/map_prensor.md b/g3doc/api_docs/python/expression_impl/map_prensor.md
index 9806d8e..03e5ea1 100644
--- a/g3doc/api_docs/python/expression_impl/map_prensor.md
+++ b/g3doc/api_docs/python/expression_impl/map_prensor.md
@@ -114,4 +114,3 @@ session: {
 [`map_ragged_tensor(...)`](../expression_impl/map_prensor/map_ragged_tensor.md): Map a ragged tensor.
 [`map_sparse_tensor(...)`](../expression_impl/map_prensor/map_sparse_tensor.md): Maps a sparse tensor.
-
diff --git a/g3doc/api_docs/python/expression_impl/map_prensor/map_ragged_tensor.md b/g3doc/api_docs/python/expression_impl/map_prensor/map_ragged_tensor.md
index 9b12225..8bb79e9 100644
--- a/g3doc/api_docs/python/expression_impl/map_prensor/map_ragged_tensor.md
+++ b/g3doc/api_docs/python/expression_impl/map_prensor/map_ragged_tensor.md
@@ -112,4 +112,3 @@ root_path.get_child(new_field_name), with the result of the operation.
-
diff --git a/g3doc/api_docs/python/expression_impl/map_prensor/map_sparse_tensor.md b/g3doc/api_docs/python/expression_impl/map_prensor/map_sparse_tensor.md
index 2f32e89..efe541b 100644
--- a/g3doc/api_docs/python/expression_impl/map_prensor/map_sparse_tensor.md
+++ b/g3doc/api_docs/python/expression_impl/map_prensor/map_sparse_tensor.md
@@ -112,4 +112,3 @@ root_path.get_child(new_field_name), with the result of the operation.
-
diff --git a/g3doc/api_docs/python/expression_impl/map_prensor_to_prensor.md b/g3doc/api_docs/python/expression_impl/map_prensor_to_prensor.md
index ce3164f..f1c5d06 100644
--- a/g3doc/api_docs/python/expression_impl/map_prensor_to_prensor.md
+++ b/g3doc/api_docs/python/expression_impl/map_prensor_to_prensor.md
@@ -86,4 +86,3 @@ foo bar foo2 bar2
 [`create_schema(...)`](../expression_impl/map_prensor_to_prensor/create_schema.md): Create a schema recursively.
 [`map_prensor_to_prensor(...)`](../expression_impl/map_prensor_to_prensor/map_prensor_to_prensor.md): Maps an expression to a prensor, and merges that prensor.
-
diff --git a/g3doc/api_docs/python/expression_impl/map_prensor_to_prensor/Schema.md b/g3doc/api_docs/python/expression_impl/map_prensor_to_prensor/Schema.md
index 065306b..4380270 100644
--- a/g3doc/api_docs/python/expression_impl/map_prensor_to_prensor/Schema.md
+++ b/g3doc/api_docs/python/expression_impl/map_prensor_to_prensor/Schema.md
@@ -145,9 +145,3 @@ child schemas.
-
-
-
-
-
-
diff --git a/g3doc/api_docs/python/expression_impl/map_prensor_to_prensor/create_schema.md b/g3doc/api_docs/python/expression_impl/map_prensor_to_prensor/create_schema.md
index d249def..81217e5 100644
--- a/g3doc/api_docs/python/expression_impl/map_prensor_to_prensor/create_schema.md
+++ b/g3doc/api_docs/python/expression_impl/map_prensor_to_prensor/create_schema.md
@@ -95,4 +95,3 @@ a new Schema represented by the inputs.
-
diff --git a/g3doc/api_docs/python/expression_impl/map_prensor_to_prensor/map_prensor_to_prensor.md b/g3doc/api_docs/python/expression_impl/map_prensor_to_prensor/map_prensor_to_prensor.md
index 9d6f5c5..caf69b4 100644
--- a/g3doc/api_docs/python/expression_impl/map_prensor_to_prensor/map_prensor_to_prensor.md
+++ b/g3doc/api_docs/python/expression_impl/map_prensor_to_prensor/map_prensor_to_prensor.md
@@ -126,4 +126,3 @@ A new expression where the prensor is merged.
-
diff --git a/g3doc/api_docs/python/expression_impl/map_values.md b/g3doc/api_docs/python/expression_impl/map_values.md
index 7553266..176b113 100644
--- a/g3doc/api_docs/python/expression_impl/map_values.md
+++ b/g3doc/api_docs/python/expression_impl/map_values.md
@@ -36,4 +36,3 @@ Note that the operations are on 1-D tensors (as opposed to scalars).
 [`map_values(...)`](../expression_impl/map_values/map_values.md): Map field into a new sibling.
 [`map_values_anonymous(...)`](../expression_impl/map_values/map_values_anonymous.md): Map field into a new sibling.
-
diff --git a/g3doc/api_docs/python/expression_impl/map_values/map_many_values.md b/g3doc/api_docs/python/expression_impl/map_values/map_many_values.md
index 5ab7f1f..5a295a6 100644
--- a/g3doc/api_docs/python/expression_impl/map_values/map_many_values.md
+++ b/g3doc/api_docs/python/expression_impl/map_values/map_many_values.md
@@ -103,4 +103,3 @@ The new expression and the new path as a pair.
-
diff --git a/g3doc/api_docs/python/expression_impl/map_values/map_values.md b/g3doc/api_docs/python/expression_impl/map_values/map_values.md
index 7b2f2c1..03d528c 100644
--- a/g3doc/api_docs/python/expression_impl/map_values/map_values.md
+++ b/g3doc/api_docs/python/expression_impl/map_values/map_values.md
@@ -94,4 +94,3 @@ The new expression.
-
diff --git a/g3doc/api_docs/python/expression_impl/map_values/map_values_anonymous.md b/g3doc/api_docs/python/expression_impl/map_values/map_values_anonymous.md
index 6a8e712..f3a1840 100644
--- a/g3doc/api_docs/python/expression_impl/map_values/map_values_anonymous.md
+++ b/g3doc/api_docs/python/expression_impl/map_values/map_values_anonymous.md
@@ -86,4 +86,3 @@ The new expression and the new path as a pair.
-
diff --git a/g3doc/api_docs/python/expression_impl/parquet.md b/g3doc/api_docs/python/expression_impl/parquet.md
index 2a3cf40..51cc719 100644
--- a/g3doc/api_docs/python/expression_impl/parquet.md
+++ b/g3doc/api_docs/python/expression_impl/parquet.md
@@ -47,4 +47,3 @@ Apache Parquet Dataset.
 [`calculate_parquet_values(...)`](../expression_impl/parquet/calculate_parquet_values.md): Calculates expressions and returns a parquet dataset.
 [`create_expression_from_parquet_file(...)`](../expression_impl/parquet/create_expression_from_parquet_file.md): Creates a placeholder expression from a parquet file.
-
diff --git a/g3doc/api_docs/python/expression_impl/parquet/ParquetDataset.md b/g3doc/api_docs/python/expression_impl/parquet/ParquetDataset.md
index b8261b6..e3c95f4 100644
--- a/g3doc/api_docs/python/expression_impl/parquet/ParquetDataset.md
+++ b/g3doc/api_docs/python/expression_impl/parquet/ParquetDataset.md
@@ -3447,9 +3447,3 @@ execution is not enabled.
-
-
-
-
-
-
diff --git a/g3doc/api_docs/python/expression_impl/parquet/calculate_parquet_values.md b/g3doc/api_docs/python/expression_impl/parquet/calculate_parquet_values.md
index 393b3b9..620b6cd 100644
--- a/g3doc/api_docs/python/expression_impl/parquet/calculate_parquet_values.md
+++ b/g3doc/api_docs/python/expression_impl/parquet/calculate_parquet_values.md
@@ -93,4 +93,3 @@ A parquet dataset.
-
diff --git a/g3doc/api_docs/python/expression_impl/parquet/create_expression_from_parquet_file.md b/g3doc/api_docs/python/expression_impl/parquet/create_expression_from_parquet_file.md
index d8f3f9e..d3405d8 100644
--- a/g3doc/api_docs/python/expression_impl/parquet/create_expression_from_parquet_file.md
+++ b/g3doc/api_docs/python/expression_impl/parquet/create_expression_from_parquet_file.md
@@ -62,4 +62,3 @@ graph.
-
diff --git a/g3doc/api_docs/python/expression_impl/placeholder.md b/g3doc/api_docs/python/expression_impl/placeholder.md
index a39e524..4ddf5fb 100644
--- a/g3doc/api_docs/python/expression_impl/placeholder.md
+++ b/g3doc/api_docs/python/expression_impl/placeholder.md
@@ -47,4 +47,3 @@ result = calculate.calculate_values([new_exp],
 [`create_expression_from_schema(...)`](../expression_impl/placeholder/create_expression_from_schema.md): Creates a placeholder expression from a parquet schema.
 [`get_placeholder_paths_from_graph(...)`](../expression_impl/placeholder/get_placeholder_paths_from_graph.md): Gets all placeholder paths from an expression graph.
-
diff --git a/g3doc/api_docs/python/expression_impl/placeholder/create_expression_from_schema.md b/g3doc/api_docs/python/expression_impl/placeholder/create_expression_from_schema.md
index bf03a42..42b6c5d 100644
--- a/g3doc/api_docs/python/expression_impl/placeholder/create_expression_from_schema.md
+++ b/g3doc/api_docs/python/expression_impl/placeholder/create_expression_from_schema.md
@@ -63,4 +63,3 @@ graph.
-
diff --git a/g3doc/api_docs/python/expression_impl/placeholder/get_placeholder_paths_from_graph.md b/g3doc/api_docs/python/expression_impl/placeholder/get_placeholder_paths_from_graph.md
index a7f008e..bfd7a79 100644
--- a/g3doc/api_docs/python/expression_impl/placeholder/get_placeholder_paths_from_graph.md
+++ b/g3doc/api_docs/python/expression_impl/placeholder/get_placeholder_paths_from_graph.md
@@ -63,4 +63,3 @@ a list of paths of placeholder expressions
-
diff --git a/g3doc/api_docs/python/expression_impl/project.md b/g3doc/api_docs/python/expression_impl/project.md
index 8fe4c1a..be38db8 100644
--- a/g3doc/api_docs/python/expression_impl/project.md
+++ b/g3doc/api_docs/python/expression_impl/project.md
@@ -41,4 +41,3 @@ prensor_result now has two paths, "foo.bar" and "x.y".
 ## Functions
 [`project(...)`](../expression_impl/project/project.md): select a subtree.
-
diff --git a/g3doc/api_docs/python/expression_impl/project/project.md b/g3doc/api_docs/python/expression_impl/project/project.md
index f8ac3e2..18f5158 100644
--- a/g3doc/api_docs/python/expression_impl/project/project.md
+++ b/g3doc/api_docs/python/expression_impl/project/project.md
@@ -72,4 +72,3 @@ A projected expression.
-
diff --git a/g3doc/api_docs/python/expression_impl/promote.md b/g3doc/api_docs/python/expression_impl/promote.md
index 7053010..a608324 100644
--- a/g3doc/api_docs/python/expression_impl/promote.md
+++ b/g3doc/api_docs/python/expression_impl/promote.md
@@ -115,4 +115,3 @@ session: {
 [`promote(...)`](../expression_impl/promote/promote.md): Promote a path to be a child of its grandparent, and give it a name.
 [`promote_anonymous(...)`](../expression_impl/promote/promote_anonymous.md): Promote a path to be a new anonymous child of its grandparent.
-
diff --git a/g3doc/api_docs/python/expression_impl/promote/PromoteChildExpression.md b/g3doc/api_docs/python/expression_impl/promote/PromoteChildExpression.md
index 540f290..677e26d 100644
--- a/g3doc/api_docs/python/expression_impl/promote/PromoteChildExpression.md
+++ b/g3doc/api_docs/python/expression_impl/promote/PromoteChildExpression.md
@@ -1037,8 +1037,3 @@ Boolean of equality of two expressions
-
-
-
-
-
diff --git a/g3doc/api_docs/python/expression_impl/promote/PromoteExpression.md b/g3doc/api_docs/python/expression_impl/promote/PromoteExpression.md
index b1f8e32..69e476c 100644
--- a/g3doc/api_docs/python/expression_impl/promote/PromoteExpression.md
+++ b/g3doc/api_docs/python/expression_impl/promote/PromoteExpression.md
@@ -1034,8 +1034,3 @@ Boolean of equality of two expressions
-
-
-
-
-
diff --git a/g3doc/api_docs/python/expression_impl/promote_and_broadcast.md b/g3doc/api_docs/python/expression_impl/promote_and_broadcast.md
index ec7c211..8e6ee46 100644
--- a/g3doc/api_docs/python/expression_impl/promote_and_broadcast.md
+++ b/g3doc/api_docs/python/expression_impl/promote_and_broadcast.md
@@ -123,4 +123,3 @@ session: {
 [`promote_and_broadcast(...)`](../expression_impl/promote_and_broadcast/promote_and_broadcast.md): Promote and broadcast a set of paths to a particular location.
 [`promote_and_broadcast_anonymous(...)`](../expression_impl/promote_and_broadcast/promote_and_broadcast_anonymous.md): Promotes then broadcasts the origin until its parent is new_parent.
-
diff --git a/g3doc/api_docs/python/expression_impl/promote_and_broadcast/promote_and_broadcast.md b/g3doc/api_docs/python/expression_impl/promote_and_broadcast/promote_and_broadcast.md
index 212fe1f..2626e1a 100644
--- a/g3doc/api_docs/python/expression_impl/promote_and_broadcast/promote_and_broadcast.md
+++ b/g3doc/api_docs/python/expression_impl/promote_and_broadcast/promote_and_broadcast.md
@@ -78,4 +78,3 @@ until they are children of dest_path_parent.
-
diff --git a/g3doc/api_docs/python/expression_impl/proto.md b/g3doc/api_docs/python/expression_impl/proto.md
index 9d0016b..b09418d 100644
--- a/g3doc/api_docs/python/expression_impl/proto.md
+++ b/g3doc/api_docs/python/expression_impl/proto.md
@@ -48,4 +48,3 @@ for its children.
 [`ProtoExpression`](../expression_impl/proto/ProtoExpression.md)
 [`TransformFn`](../expression_impl/proto/TransformFn.md)
-
diff --git a/g3doc/api_docs/python/expression_impl/proto/DescriptorPool.md b/g3doc/api_docs/python/expression_impl/proto/DescriptorPool.md
index f5e36b6..e04a088 100644
--- a/g3doc/api_docs/python/expression_impl/proto/DescriptorPool.md
+++ b/g3doc/api_docs/python/expression_impl/proto/DescriptorPool.md
@@ -806,8 +806,3 @@ if the service cannot be found in the pool.
-
-
-
-
-
diff --git a/g3doc/api_docs/python/expression_impl/proto/FileDescriptorSet.md b/g3doc/api_docs/python/expression_impl/proto/FileDescriptorSet.md
index 373d76c..d96631e 100644
--- a/g3doc/api_docs/python/expression_impl/proto/FileDescriptorSet.md
+++ b/g3doc/api_docs/python/expression_impl/proto/FileDescriptorSet.md
@@ -36,6 +36,3 @@ A ProtocolMessage
-
-
-
diff --git a/g3doc/api_docs/python/expression_impl/proto/create_expression_from_file_descriptor_set.md b/g3doc/api_docs/python/expression_impl/proto/create_expression_from_file_descriptor_set.md
index f0c2413..05744af 100644
--- a/g3doc/api_docs/python/expression_impl/proto/create_expression_from_file_descriptor_set.md
+++ b/g3doc/api_docs/python/expression_impl/proto/create_expression_from_file_descriptor_set.md
@@ -90,4 +90,3 @@ An expression.
-
diff --git a/g3doc/api_docs/python/expression_impl/proto/create_expression_from_proto.md b/g3doc/api_docs/python/expression_impl/proto/create_expression_from_proto.md
index 4a2744c..02e5c1a 100644
--- a/g3doc/api_docs/python/expression_impl/proto/create_expression_from_proto.md
+++ b/g3doc/api_docs/python/expression_impl/proto/create_expression_from_proto.md
@@ -78,4 +78,3 @@ An expression.
-
diff --git a/g3doc/api_docs/python/expression_impl/proto/create_transformed_field.md b/g3doc/api_docs/python/expression_impl/proto/create_transformed_field.md
index 7555fb9..820ada0 100644
--- a/g3doc/api_docs/python/expression_impl/proto/create_transformed_field.md
+++ b/g3doc/api_docs/python/expression_impl/proto/create_transformed_field.md
@@ -123,4 +123,3 @@ if the source path is not a proto message field.
-
diff --git a/g3doc/api_docs/python/expression_impl/reroot.md b/g3doc/api_docs/python/expression_impl/reroot.md
index e1bf645..318fc73 100644
--- a/g3doc/api_docs/python/expression_impl/reroot.md
+++ b/g3doc/api_docs/python/expression_impl/reroot.md
@@ -32,4 +32,3 @@ original proto.
 [`create_proto_index_field(...)`](../expression_impl/reroot/create_proto_index_field.md)
 [`reroot(...)`](../expression_impl/reroot/reroot.md): Reroot to a new path, maintaining a input proto index.
-
diff --git a/g3doc/api_docs/python/expression_impl/reroot/reroot.md b/g3doc/api_docs/python/expression_impl/reroot/reroot.md
index 7b99d93..098c356 100644
--- a/g3doc/api_docs/python/expression_impl/reroot/reroot.md
+++ b/g3doc/api_docs/python/expression_impl/reroot/reroot.md
@@ -71,4 +71,3 @@ the new root.
-
diff --git a/g3doc/api_docs/python/expression_impl/size.md b/g3doc/api_docs/python/expression_impl/size.md
index 94ec2da..77c413d 100644
--- a/g3doc/api_docs/python/expression_impl/size.md
+++ b/g3doc/api_docs/python/expression_impl/size.md
@@ -50,4 +50,3 @@ is always present, and is true if there are one or more bar in foo.
 [`size(...)`](../expression_impl/size/size.md): Get the size of a field as a new sibling field.
 [`size_anonymous(...)`](../expression_impl/size/size_anonymous.md): Calculate the size of a field, and store it as an anonymous sibling.
-
diff --git a/g3doc/api_docs/python/expression_impl/size/SizeExpression.md b/g3doc/api_docs/python/expression_impl/size/SizeExpression.md
index 3adb5e0..f8e720d 100644
--- a/g3doc/api_docs/python/expression_impl/size/SizeExpression.md
+++ b/g3doc/api_docs/python/expression_impl/size/SizeExpression.md
@@ -1036,8 +1036,3 @@ Boolean of equality of two expressions
-
-
-
-
-
diff --git a/g3doc/api_docs/python/expression_impl/size/has.md b/g3doc/api_docs/python/expression_impl/size/has.md
index 3dc9d63..584ccb5 100644
--- a/g3doc/api_docs/python/expression_impl/size/has.md
+++ b/g3doc/api_docs/python/expression_impl/size/has.md
@@ -77,4 +77,3 @@ The new expression.
-
diff --git a/g3doc/api_docs/python/expression_impl/size/size.md b/g3doc/api_docs/python/expression_impl/size/size.md
index 819c9e3..7034f51 100644
--- a/g3doc/api_docs/python/expression_impl/size/size.md
+++ b/g3doc/api_docs/python/expression_impl/size/size.md
@@ -77,4 +77,3 @@ The new expression.
-
diff --git a/g3doc/api_docs/python/expression_impl/size/size_anonymous.md b/g3doc/api_docs/python/expression_impl/size/size_anonymous.md
index 792d371..7f6189c 100644
--- a/g3doc/api_docs/python/expression_impl/size/size_anonymous.md
+++ b/g3doc/api_docs/python/expression_impl/size/size_anonymous.md
@@ -69,4 +69,3 @@ The new expression and the new field as a pair.
-
diff --git a/g3doc/api_docs/python/expression_impl/slice_expression.md b/g3doc/api_docs/python/expression_impl/slice_expression.md
index 9ba697d..67f6f8d 100644
--- a/g3doc/api_docs/python/expression_impl/slice_expression.md
+++ b/g3doc/api_docs/python/expression_impl/slice_expression.md
@@ -132,4 +132,3 @@ result_2 = [{
 ## Type Aliases
 [`IndexValue`](../expression_impl/slice_expression/IndexValue.md)
-
diff --git a/g3doc/api_docs/python/expression_impl/slice_expression/slice_expression.md b/g3doc/api_docs/python/expression_impl/slice_expression/slice_expression.md
index 9de0d23..7e6a205 100644
--- a/g3doc/api_docs/python/expression_impl/slice_expression/slice_expression.md
+++ b/g3doc/api_docs/python/expression_impl/slice_expression/slice_expression.md
@@ -95,4 +95,3 @@ A new root expression.
-
diff --git a/g3doc/api_docs/python/s2t.md b/g3doc/api_docs/python/s2t.md
index 6f5e8c6..547f872 100644
--- a/g3doc/api_docs/python/s2t.md
+++ b/g3doc/api_docs/python/s2t.md
@@ -75,4 +75,3 @@ Import core names for struct2tensor.
 [`NodeTensor`](./s2t/NodeTensor.md)
 [`Step`](./s2t/Step.md)
-
diff --git a/g3doc/api_docs/python/s2t/ChildNodeTensor.md b/g3doc/api_docs/python/s2t/ChildNodeTensor.md
index cc3d3c1..4c6ac42 100644
--- a/g3doc/api_docs/python/s2t/ChildNodeTensor.md
+++ b/g3doc/api_docs/python/s2t/ChildNodeTensor.md
@@ -132,8 +132,3 @@ A tensor of positional indices.
-
-
-
-
-
diff --git a/g3doc/api_docs/python/s2t/Expression.md b/g3doc/api_docs/python/s2t/Expression.md
index dd32541..e2e5036 100644
--- a/g3doc/api_docs/python/s2t/Expression.md
+++ b/g3doc/api_docs/python/s2t/Expression.md
@@ -1095,8 +1095,3 @@ Boolean of equality of two expressions
-
-
-
-
-
diff --git a/g3doc/api_docs/python/s2t/LeafNodeTensor.md b/g3doc/api_docs/python/s2t/LeafNodeTensor.md
index eb7bbb0..077465c 100644
--- a/g3doc/api_docs/python/s2t/LeafNodeTensor.md
+++ b/g3doc/api_docs/python/s2t/LeafNodeTensor.md
@@ -140,8 +140,3 @@ A tensor of positional indices.
-
-
-
-
-
diff --git a/g3doc/api_docs/python/s2t/Path.md b/g3doc/api_docs/python/s2t/Path.md
index 6fa08b7..1b03ff5 100644
--- a/g3doc/api_docs/python/s2t/Path.md
+++ b/g3doc/api_docs/python/s2t/Path.md
@@ -334,7 +334,3 @@ Return self
 Return self!=value.
-
-
-
-
diff --git a/g3doc/api_docs/python/s2t/Prensor.md b/g3doc/api_docs/python/s2t/Prensor.md
index e26a368..948309f 100644
--- a/g3doc/api_docs/python/s2t/Prensor.md
+++ b/g3doc/api_docs/python/s2t/Prensor.md
@@ -377,8 +377,3 @@ A map from paths to sparse tensors.
-
-
-
-
-
diff --git a/g3doc/api_docs/python/s2t/RootNodeTensor.md b/g3doc/api_docs/python/s2t/RootNodeTensor.md
index c38c929..6cfce4a 100644
--- a/g3doc/api_docs/python/s2t/RootNodeTensor.md
+++ b/g3doc/api_docs/python/s2t/RootNodeTensor.md
@@ -104,8 +104,3 @@ A tensor of positional indices.
-
-
-
-
-
diff --git a/g3doc/api_docs/python/s2t/all_symbols.md b/g3doc/api_docs/python/s2t/all_symbols.md
index bd052bb..5ae4430 100644
--- a/g3doc/api_docs/python/s2t/all_symbols.md
+++ b/g3doc/api_docs/python/s2t/all_symbols.md
@@ -26,4 +26,4 @@
 * s2t.get_ragged_tensor
 * s2t.get_ragged_tensors
 * s2t.get_sparse_tensor
-* s2t.get_sparse_tensors
\ No newline at end of file
+* s2t.get_sparse_tensors
diff --git a/g3doc/api_docs/python/s2t/calculate_prensors.md b/g3doc/api_docs/python/s2t/calculate_prensors.md
index ee30dea..83375a7 100644
--- a/g3doc/api_docs/python/s2t/calculate_prensors.md
+++ b/g3doc/api_docs/python/s2t/calculate_prensors.md
@@ -78,4 +78,3 @@ a list of prensors.
-
diff --git a/g3doc/api_docs/python/s2t/calculate_prensors_with_graph.md b/g3doc/api_docs/python/s2t/calculate_prensors_with_graph.md
index e1a05b3..1a327f9 100644
--- a/g3doc/api_docs/python/s2t/calculate_prensors_with_graph.md
+++ b/g3doc/api_docs/python/s2t/calculate_prensors_with_graph.md
@@ -80,4 +80,3 @@ a list of prensors, and the graph used to calculate them.
-
diff --git a/g3doc/api_docs/python/s2t/create_expression_from_file_descriptor_set.md b/g3doc/api_docs/python/s2t/create_expression_from_file_descriptor_set.md
index 77c80eb..f8f9e0d 100644
--- a/g3doc/api_docs/python/s2t/create_expression_from_file_descriptor_set.md
+++ b/g3doc/api_docs/python/s2t/create_expression_from_file_descriptor_set.md
@@ -90,4 +90,3 @@ An expression.
-
diff --git a/g3doc/api_docs/python/s2t/create_expression_from_prensor.md b/g3doc/api_docs/python/s2t/create_expression_from_prensor.md
index f72995e..582cd0f 100644
--- a/g3doc/api_docs/python/s2t/create_expression_from_prensor.md
+++ b/g3doc/api_docs/python/s2t/create_expression_from_prensor.md
@@ -61,4 +61,3 @@ An expression representing the prensor.
-
diff --git a/g3doc/api_docs/python/s2t/create_expression_from_proto.md b/g3doc/api_docs/python/s2t/create_expression_from_proto.md
index 4a720e8..104735e 100644
--- a/g3doc/api_docs/python/s2t/create_expression_from_proto.md
+++ b/g3doc/api_docs/python/s2t/create_expression_from_proto.md
@@ -78,4 +78,3 @@ An expression.
-
diff --git a/g3doc/api_docs/python/s2t/create_path.md b/g3doc/api_docs/python/s2t/create_path.md
index 65ba841..af9bc12 100644
--- a/g3doc/api_docs/python/s2t/create_path.md
+++ b/g3doc/api_docs/python/s2t/create_path.md
@@ -91,4 +91,3 @@ if this is not a valid path.
-
diff --git a/g3doc/api_docs/python/s2t/create_prensor_from_descendant_nodes.md b/g3doc/api_docs/python/s2t/create_prensor_from_descendant_nodes.md
index 4222a52..79e0ae2 100644
--- a/g3doc/api_docs/python/s2t/create_prensor_from_descendant_nodes.md
+++ b/g3doc/api_docs/python/s2t/create_prensor_from_descendant_nodes.md
@@ -79,4 +79,3 @@ if there is a prefix of a path missing.
-
diff --git a/g3doc/api_docs/python/s2t/get_ragged_tensor.md b/g3doc/api_docs/python/s2t/get_ragged_tensor.md
index 04eaeaf..2d849db 100644
--- a/g3doc/api_docs/python/s2t/get_ragged_tensor.md
+++ b/g3doc/api_docs/python/s2t/get_ragged_tensor.md
@@ -83,4 +83,3 @@ structure along the path. Raises an error if the path is not found.
-
diff --git a/g3doc/api_docs/python/s2t/get_ragged_tensors.md b/g3doc/api_docs/python/s2t/get_ragged_tensors.md
index 3bf2171..2d9569c 100644
--- a/g3doc/api_docs/python/s2t/get_ragged_tensors.md
+++ b/g3doc/api_docs/python/s2t/get_ragged_tensors.md
@@ -72,4 +72,3 @@ A map from paths to ragged tensors.
-
diff --git a/g3doc/api_docs/python/s2t/get_sparse_tensor.md b/g3doc/api_docs/python/s2t/get_sparse_tensor.md
index b6ec8ab..9ab4355 100644
--- a/g3doc/api_docs/python/s2t/get_sparse_tensor.md
+++ b/g3doc/api_docs/python/s2t/get_sparse_tensor.md
@@ -84,4 +84,3 @@ structure along the path. Raises an error if the path is not found.
-
diff --git a/g3doc/api_docs/python/s2t/get_sparse_tensors.md b/g3doc/api_docs/python/s2t/get_sparse_tensors.md
index 4c50fb7..baff210 100644
--- a/g3doc/api_docs/python/s2t/get_sparse_tensors.md
+++ b/g3doc/api_docs/python/s2t/get_sparse_tensors.md
@@ -72,4 +72,3 @@ A map from paths to sparse tensors.
-
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..e69de29
diff --git a/setup.py b/setup.py
index 45118d5..4705b17 100644
--- a/setup.py
+++ b/setup.py
@@ -16,8 +16,7 @@
 import os
 
-from setuptools import find_packages
-from setuptools import setup
+from setuptools import find_packages, setup
 from setuptools.command.install import install
 from setuptools.dist import Distribution
diff --git a/struct2tensor/__init__.py b/struct2tensor/__init__.py
index dbf5aac..1202eeb 100644
--- a/struct2tensor/__init__.py
+++ b/struct2tensor/__init__.py
@@ -14,41 +14,47 @@
 """Import core names for struct2tensor."""
 
 # Import calculate API.
-from struct2tensor.calculate import calculate_prensors
-from struct2tensor.calculate import calculate_prensors_with_graph
-from struct2tensor.calculate_options import get_default_options
-from struct2tensor.calculate_options import get_options_with_minimal_checks
-from struct2tensor.calculate_with_source_paths import calculate_prensors_with_source_paths
+# Importing this will register the session handler for PrensorValue, and
+# tf.compat.v1.Session.run() will be able to take a Prensor and return a
+# PrensorValue.
+import struct2tensor.prensor_value
+from struct2tensor.calculate import calculate_prensors, calculate_prensors_with_graph
+from struct2tensor.calculate_options import (
+    get_default_options,
+    get_options_with_minimal_checks,
+)
+from struct2tensor.calculate_with_source_paths import (
+    calculate_prensors_with_source_paths,
+)
 
 # Import expressions API.
 from struct2tensor.create_expression import create_expression_from_prensor
 from struct2tensor.expression import Expression
 
 # Import expression queries API
-from struct2tensor.expression_impl.proto import create_expression_from_file_descriptor_set
-from struct2tensor.expression_impl.proto import create_expression_from_proto
+from struct2tensor.expression_impl.proto import (
+    create_expression_from_file_descriptor_set,
+    create_expression_from_proto,
+)
 
 # Import path API
-from struct2tensor.path import create_path
-from struct2tensor.path import Path
-from struct2tensor.path import Step
+from struct2tensor.path import Path, Step, create_path
 
 # Import prensor API
-from struct2tensor.prensor import ChildNodeTensor
-from struct2tensor.prensor import create_prensor_from_descendant_nodes
-from struct2tensor.prensor import create_prensor_from_root_and_children
-from struct2tensor.prensor import LeafNodeTensor
-from struct2tensor.prensor import NodeTensor
-from struct2tensor.prensor import Prensor
-from struct2tensor.prensor import RootNodeTensor
+from struct2tensor.prensor import (
+    ChildNodeTensor,
+    LeafNodeTensor,
+    NodeTensor,
+    Prensor,
+    RootNodeTensor,
+    create_prensor_from_descendant_nodes,
+    create_prensor_from_root_and_children,
+)
 
 # TODO(b/163167832): Remove these after 0.32.0 is released.
-from struct2tensor.prensor_util import get_ragged_tensor
-from struct2tensor.prensor_util import get_ragged_tensors
-from struct2tensor.prensor_util import get_sparse_tensor
-from struct2tensor.prensor_util import get_sparse_tensors
-
-# Importing this will register the session handler for PrensorValue, and
-# tf.compat.v1.Session.run() will be able to take a Prensor and return a
-# PrensorValue.
-import struct2tensor.prensor_value
+from struct2tensor.prensor_util import (
+    get_ragged_tensor,
+    get_ragged_tensors,
+    get_sparse_tensor,
+    get_sparse_tensors,
+)
diff --git a/struct2tensor/__pycache__/__init__.cpython-311.pyc b/struct2tensor/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86f3ebe5b213591fa8b921da909e38d8d2432518
GIT binary patch
literal 1948
[base85-encoded binary payloads omitted for this and the remaining committed __pycache__/*.pyc files under struct2tensor/, struct2tensor/expression_impl/, and struct2tensor/ops/; the full set of added .pyc files is listed in the diffstat and create-mode entries above]
zZQMuLp|l{mu4eu*o~=N$$V`hx7+@f4P8B$~Sx6C97KKlPSTyNg4|cDPtsTw~39bR^e;+dU4vp%)vLBkl?+E zVw^y|KXz)ilKq?6O5Sy1ZPI08wRCN=GJZ=f^RTcCVaU)D5Ug2DA~n`8l^OnZ0ypJ0 z1#eM6Qj})nT3oc4UrK#OnQZi1kV;`;U^U( zz;9B(PT^X>6ZelT15(LQDmx-{J0Drp3Lfr`{}!GZ7TyxY7daL_!6%z8PoA@76_5;* z$|C8n(B!AgEx%>kMoW0TC7f&7r?%|ldmnNwN7NQ*Jpqr81@8uz+)HxaU!(eWWc@vP zU!CSTRCL;!ZNp#dj*XU~^_HPr%R#l}V9}41pAnM3px`rCRTbPt5fsg!5z`qKQ$0Sd z^iKmd_xhL5FelS&;6g5NK@D8^%;lt<2w+J-0D-!p(k?KAH!g$2Y`(;1AJNarS$>~+ zEs`gxqq4`Ry;S&Wx=bN_u@P?ucS%)C&IR|B!671E^bWfcZSZq6_=5hG?$*-EID%`{ zn$&onEQq0|<7!8cSTkfTQJMrto3$RjuN) zaAZ%W;5&YvoJC-`CN2z56;z~on$BJy3ZLg~=A!&%Wey1HuQS<2!38rC=NaVdPzW-A zvIx#YOr#60ndohtu#N@a^H?BJ2>XSgRv?`bS@2#Pd-b)mW9)>rLT5y8yY2_6!>wZ& z{(M->jbWD@I|H&%r6pi&S2M>Hs`XsNLODbJ6IL;^81dGWNN38wMS{$rMRNE`P!1B4 z`VYtkuG}DCH_@j2Dd=e@Z;Ez}KQ)RQjos^w-MPkHYU8fOnK?Uv4{SyLy|5Me_fjo;pKjf;cvJN@XT8nE79C11 z1)*`qW$>u~h)>{Aj?__jR~ymRht2kX*~hh=8$1thx+Yk?&P~2#FBdYDP+5wBY~|JA zwAW4}pH{!3$MM?CsP2FnN_MQGXR%=utiastfFg7kCyv z?!DJo>1?Pe{~k#+yMLj%GX9u9wauiZ&~VE%Sw<;|54?^`G#uYTJR6SHD*n|^at5?Z6%K5z$V2$c0vdZ0)F@cry)MY{n`;IxCvd@ z#_gWMhB}4Z^mUz+_VUNww)@ zE_g}}LWw?Au7VO_G2#;OU^jl0>-9KLc1-7CWLAnB9xAdyB6Jcl#lf0&L&MkzKN*D#GF>WYC1_&Fku9Q>Y63#-6QvoEONgG zEO;Y0g9-g%T@c})=BPJFKtpdpgVB29$I43g4{rbH_Il7bw$ru} zZBSpBb;@SbzU07eVQ{FA@Sk7Av@&}z0n*Uc2n)v4L2a(l45@s&;o~!`{06cruOomr z0gb9MdaH`is6z9sC+=$BzDw=fid@hbgS8vM9qYjzt7o#o9l2ny8tl!wd)e6PcU~UT zwwqtZISmS}VdDtfvHq9oaA*>rBauQN5}8R#bJO%(9f?qnXfc35O2+WneH~qcH%jo=c>;0XiFc z1>c~6*)hIJPqfsf&_u=sdkIek7r44YDj(l8MO?5@dzMZMr_*P~l2_r{q7a1?s;@Ek z_%XH*T3MznKc>J>kzVcd@6t3|c?pV8BwQ}8JT7bxhXfH<*2W@#3@hKJPe0Mm=czr6EY z)_Ja2EgT-s@9ZrGTV02X0)nSoT8bWe@d|aVE6)2qwYIZZMIU@ZO9zgE>DaF}4;B6N zB_QnRG;^}-JUA%0tBX#jYsU(9nc`*!pSciq_Xvwbi9A{J^2L=`@Bapr1WMR!2A0&# zM(WPxG^N1t1$tmf)@dXYS8l9k)ZU|N`!J=1x}gV_xO!u4LEUv6rWh(!4=8!Fo}}2SrLPvkL&abhB}DMFxqao9hCvT~^a@Qo zR^Ct>2Z~kn>J#cZ)!M$IpI!n&L)%L4{k>}a052gRyl_%Ie9=QCb(PpDJAKUF&|- z|M-pUz-hG$Md4FvP_q@uzO^>?==CSn**zm_&$*(2Po=@(MnEJz0%&;1@7%q1_Q8pw zi=Gc1eX{f8*V!|Rv;_NlUEZ}TMFF>b-@eCYbnx)UNBFZ(7zjVwTJ+O*pKxTv!5^tmRx%G>b>RK)tuB}A~BClNv^IQ9hE_!&n zUA#Y9^spGO(71i&lDeh0SjFCbG|~pt`W`K(PZ&CBErb`Ph5MYYtB%zUVkcJ}pLG8M z6iQDOLZ2J0IaL&J`{dv+K%jWYi`@_WMHfAVhuNKdwglBq*GrC-+iSII2TZ>7@&t59 z_X?rU;~LEG!0OI}*B{qrdrql4UM#xs?P+Uo(L?vaBkWGZi4krDH4UyKYd|z^`Of~e zp+`s8c0X1=p3e@vsCJ$%y72vJyLkUb(Ss+@Pv0QML=}HNe8#~aDK{&E2YpW0v6Tn` zdF(NUA#SCC&*i%0SOZ4rj*bEYL3iTu@EQKVV{Bo-h0Y^T>7MW14LFbokB0yzJ@6P? 
z7^rid;4A?T#1e2T4Z7W~bF5stenrRL^USPM)X}x16HJ_iY+*d4)t3?11q@D2%c*N@Mt7TqGmX_4$1S zv_IMYTYyrvzX>nflkB}48V^WSa&XEprb|A#O7fFz%`g1s z5`GnMJcRG~^*%G5B?bEx!13B8;kIME4!S@Hzo)D2(elJs@}!-Qy?&pe5%OpCmV2im=8)fGR)w~z3j zUzJx;=pdN|SLtRZoB>BxLRaRO)*n$~KNX^>jOI2Q?Uzp2c47XrLD*s{$Kr4>koZrj zavjjBL)tcsa`T%NA+W>gh8(VtV>dGYnJMm$Y5d+2DfY~fK~JU`BQ*2huwmgdmsrss zgyNE%P{C*!kdCtRuzq6NjPCGAdZhV_Jq3Rx!iW>97~x;UrLB6B3X};^*BigG^ByT` z1Ne@$y)WAYPvE;3?q1k%H?F%Im%qK*`y`xoH{$pd)lFVz{Ev?XvI{VMjmhpd+X1Pa ztk+3{q2K7Xx;T`%voL)q)5xJB3nC*eb|~&o2sqN}a6nkFcIhCi*`^UA0k4Www`Y-; z1_F&Aq~a8p2l6hc}*#^mYgNCYJC zzwk9h$kmo_zo4WPe4k+ZH+VoN>i_c2m$J^6AfKN)ixv9j+FOER8ox@f86_E48e@ zkK^%|crJLshF~saN-oC}(KpwV8D@zCb1hd}n9_c_m+9yn(bR*50N*t~fgfsMom1#C zKj4R{z~vHQ_A}Mf_>JJmH2aZYg;p&K+s%X#ng#F+a#8vrV4TkV&OY9=|DDoRO-?7T z<0t@&Ld|q(=CMqUMBHJNY1Pg2%9Lkvhf)3;1w{k}-#4S!%?{;W`HWt@{8zyXRj1FK zeFb%M&6yFw1`4@px(5=$AxRELaeJM^4LJirx8Mja2Iya2 zxSIX?e@}(oS^HmJXwTaJ@0~fgC${dW%4KS z`|)i5m7M>o>c6_^eTrJAR(7matyMkrKl10dj-b}s^HeLs#UN^}Yg@{!gx7l4dLQ;b z>d&{GM~b?Olmg*mHBvOTFArzi_dV`?-23zXkNflOmyx9LDkVXf73#ncYPKvNfX&3e zI`rUJc1u_d?WHCqd=7{|mBZ!c)dLS8C{yfbEd0WnQmSuuU__S(OT(Yh*DoqWF>3wa DiLDw& literal 0 HcmV?d00001 diff --git a/struct2tensor/__pycache__/expression_add.cpython-311.pyc b/struct2tensor/__pycache__/expression_add.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e214098ac501fc3d3d957444201f8bf14deb0861 GIT binary patch literal 10636 zcmb_iYit`=cAg=JZ%VW**?L%3#_~&)6)JwL?a1;&wqwUHC0V;^R@s22IYWsuUzr)& z5lt2J!f3UD(0I3LtpqN@A3@fs1JnzoKzKDtw><~j(fmK+zco2 zk}bg}Z8Nr{ea4;?W`v|;#*uW+IC;vmCtOMQjGLteq&+hpmUbZRo$<1?GvQ16XZ%RJ zWV`H{@2EE?yXQOW>6op*p5!FA?39Jp+-4rr5U3Y@F@3GNEonjW+!JOt*|_tZ&K)7aXA=ZAR@a zQukFZ*b{MJ`Cj`VqsXcnPp4EdmR7`wB*jy+Vk|BvBsJvq;y#>-Xmjd>7)goQlp2Z2 z;*G2(X41E1C6-Nyw<8N;I@VTEybu!?(pfR0$l`4!u941&R}@q7ZL!&;sAg|O6A@J% z6XzngSQ$~pS3(wJk*Jnd7R02i&7~z#n~P{-G?GZDNXAuYie*z#4c!*ws+d(}^rY5X zEsd_SUTJADn$9fL1|!~xL~n|ChH=OdB@veuQB!0Y)qY2d7jzh&b4ofVr>@GH;kpvZ zWY87E{YHkGi6jj7Rr&p_oQleZ<61V8kb}J8l^fHCJAoz&-!4pPN=+-2G;6DdoAScR zNXBp~vX)g+QA~d05|APrdj9ti54bqD$V>d9EzcEkzyA6Hj~+K3pXJ^=^-~+iE!r1_ zyj`{Dg+<%4H*cSBv>>(Ynng$6UPS5jmt_9}H6Ax=Ejsg#Jb#N<9C0qs=N*%%k{2k| z+-G?wJzbv4+wy$O2CEWs6Q>n9a#O_`(Cmuw6ivCPq|=(1N=q`PUlOG_Te}Dr7^VX& zprUGMTA_~pj(QNZ8;-apCso7CY%!dOtD50dv$0tGgP=?Cqm1FM*E9m8KipgmgTHC; zx2Ue7YGGBAGsz#SJtQ zN&@%Nw`3(jc%1?eW=I*SPAl=*I1sM2>DzH_4k)W;{D8l97qk#G!qH>=1-&#g!zF0p+K)Cj{T zCL>zVZ8+3yMpg`WI2=#KwQ$(*#nrGPXK2eyhM+AV_a*37L+s50T_dIkz-?_I8J*^aO2 z0{)Z%>H)DYi;JQ~Vhc(*`hd%G^Y}KeS^g^bB|jbHmECyoC9mwomr4~xHXLK?tPBuN zXO*bjg3)SjbfXb&T=qPvq)97q75}E?%S$IpLO<)O(NWWkh0DsS;g@6;d?P|$TQyv1 z=DyWv!!V$O#=*w&HWF6T%=1TJ+uZx)2&Wa|pSd1v zZ#fDZ!KRZ%v#TNI=U{yHgq$+GZ&Sx#Qxq@;fz6%rGHNOz65|kct-6LvuA!&?t)ec8-8KZ3MM`o!8iceM@# z)Gw0A?f##Jz=~ZXO+65-^J~7ntG>NuUr_f2OG3~bNsZK`B1zdY9n{5t zn6z3Vk;W!)7F%oLP% zr`5N`nXkifgpMT4ndO4l6!L7~r{ovum_qCnUJz81++5QWgK!eG86Exegzkw)l&-w-*`q%nGt9_wIxpLox-ZxPawp0S$%YU*K*s&VeQ5-D? 
zUe*II7hbQp0wq^(rDLe%9C|wR@}s%(&=Gy;NLkocxN`q!#no4Gi51V*l4s~?XJ4uR zjdJIkdgq%Z_nQ?@SC!*klTQP~#p6##%7J5g;8@9htkN-1at>5nJ*db~=5*PAM)#j7 zInO-x4V7Mc1!aB5bl#QmBW$$6ZA7DK+ki%U!{KCF$|fl74~NOk)^gBcDIE=m z6*_PlUht?|q)HKJ(B!HN;VMYFA<#LBk!FQZR3TGPyp->qM#$kBJHIRBuatxIRJf5& zCz$hBUZsa*L&_-erAOsNA{+*57{OD!ZCfGr?N%+0?p-Rt?-V?0ak$kqPcSU&UDIR%r=&M(s%_t z!+So*m?`zj(QgLySEuz?-eC7t4i8(zFmJ!W7yD^O^c_DUT^INbq|g;$Z zs&naVLPAu7LdcS-APRywAN=Aqkx)ciFi{hJ#97C)dlK{8C0^s$SKFgj`X4MH41MkEvPOuJB zsvaFOW8{=ypHPagV_rp~wMf!qQd^&oTKtnK|1pbo*YukU$fCk*saOTzGu}LqCc0$} zQ-rup8Ac+C)49NO8qFcP&U$xeG1o)LFGdnD5s4-p{lBvqbelf1lL$1_M=>8*Im1i$ zW_6RVU&aPAfu4@;CVh7(KcTcA!r(Iok3bK`8cxT^lvI^m1q~Cnp}BD7!{=*J#Og zfN|xKwchd7-ti}Vx%YtHd!Xz;sQV9=oCjMfmR;L**Y=VtXddC$dJnDk9{Rhh-*o-@ z?XTZ1_fG1)lV$&L-G98~1kiKt(f{ZkE*>p+kLlfGcSD68g?FEN23M{XkCi>+x@Wv( z8)q14AheK@SoF&fD2!r3NQXI|*#ZNSti*3&_c9}D3;-6Sq_eXyj`|8Q zgS?Wa%le(jy8JUC~0mzahSDlmK-);z?pAJ*ih#pENW{;P{1$6CYzxRM`aKLHg^j;)!Z3Y zwG{i^pe5=gsR)^*12~CHm%%%Xyx3BOixKH zQ_cwp0K*99^e+9ReyQ`zJV@C)psCMW|)OzJZW8Z$>bYd^t^8F>?v6s-eep#vTH@>*H^5Fj0m z)BBX#D=DgXz^uoJ0jo!B$BguWX4OE3dEhz?pjrppYt<wh1uBTwyih zSTbs`ww4n_&*;|+<|;P`8ogwPZ^BP0*fcb30T1fqT#uwvW+xhqjmc8 zqG4Iolict$R*8NA0LNWJVET1Qll$FSoX+)3)p%PpozfzJTyVFLCAdSUc^2F`g#`UB z0>Pba46x|e5j-FZ&(~k=0iVCHEe|#WKD}02h`?-EW-njt$_40B-h(iF=2WsEtxDL1 z*=47F#DwVA3ua7cpytk+;h#p{Z&62`gaG1m_D-QG{>|IBv}5UA@G_TsY1^7-Y}GUN z=q;0pRqd9dX1K(oY=R!;rzDt1iBpP2oSO{xU3zvM0s&~oV2P!%u_*IM89uUtI{PJd zyDs4DGzX(3k@+hs@I3@0vi;&`9iMi55_lNE?-Gc}LflQ^rwZFDo4X&3{c`LVp^rm_ zOK@5hw|~vOYt_B0_-@&~Uw7}XxVzWf`&Qli9$hNCkLvED-#TpmiSIZF%!uPMkC&Gn zy`|n`YrUsedr$w%!GC!Fw}=1vaJl!Q-g~j^zoh#wm7JIUFVwdx)Ti{>;Lgvke0t@R zHy*xGb>QK*v@YL~D2)8x&3U`;A6n@vdxmw-aLG2z=6*W3RrxWd!EoZ#g$-o}h=>WB zK_fCyg`5c9Lc-V^KNevzD;ArI_TzMfFyc2eu^(~r(WdqjN4VR zek$Bh$@-~qJ4)71g*#ZXeyYwces`sF%YEtNcyU1Aak#Ygu-OPO)HiqKAz$2FEyvSD_s9WW%Vq}6x zIB30C8Px~Izom@tnz=c?kAH%XXAYWG7NGi5w4?X18QI6shT^ zoR~{vIUest_G)&NfrIe&5-fL-$!L%a&cMwA*~9K4f9%~Y4tD@)bP$4o0Aug+$NZ<9 zGeD5T;qF(jn`Ad7+jFzIHdWp4b@ltI_p0hu)$=d?elLgX+WP;V{U3)o?!Qx_e7TH= zNB_)o-20rw&2SPg*<$<*kGl}F%?mTaynV(#@0fASJ7=6cy|>3)^X?fpdv?S;^WGUB zo}Dvwl555GcA%&Zj$Qo=aK7Y zosu7Utv2pFCk1YEQoY>vvA}Vk;7`A1+9i|-+%2Z(PSn|s^cJMMrLD5>T6?)o#+z+;BT5~zAUoy8 zYLBZ=Qm5>`w!`$5Q*N)mMc=ohgr28h*`+RE^-A5=TAuQiQxfr2pJbcjf;++;AoQ4c zF`T?AD)NFNt8zRUjn9c;ae?w?mBhT5yef-IB9RoMamrWZxSCKz`E5d1qlx&S*Lw~b zYIJ@fCX11VMy^sV zb8^z?OXO-aCec^aEwn$m7f+E`6i?!O_=Y$ejib7&s5%l0tE#xD%Cn0x^y{}oqRF7W z+5hh;Eh~uf$PV$^JrAO$ParSCZ9;5?x3p6r|kL^U%VT{i&gWlOt~o z09+c|Hxu5MdS>6z(9j$EF7J6`Ft~d|UND$yw&D_MFx9y+^NmZFet7xv?l&$6OHM^j zE-LXzc{!u|o@hLyYg}F3``Cv=!Ixi=Q0CTa#m2eD0y%>+8 z<6=0Th~Jo}fd&m@(%^3~Z#eX{H8gqlwdv5<=&gkErn(T0$U~SAi;?91B%53{DH_eeg&U=2 zP&8C&Nr<&XsYeBS@Tb!H$#B1Ob9Fm1{cE=N>~O)>nYXnUJ@uJ=kMNQ!@!=AGB|>9U zzU+8lvQxyfDa;SE@;*^#&T4|I%iPqT<7V{PzFgQfSnIwvsohM&oKSOv}$_>O|14G}2Hu z0hH|FXgpc6UjbBUgsIs|fZq5+k;FW$`w;a&p?Puuf9ekbFk!ZE-jD4^F znZ7kovpMz7_G`s$X3fw|7{ zf}bUd9^X=v(ts)~>yptPnU7{L)86M~4w99d!CaUTq#cqV+hvDI4%j67ZO4pLa!8P; zvI{cD1>Ws8(jKJUNPA6b57ItUn&fJoDeXhrZ%Wr89WbRy=GL>iY>>M4nFgsxs>i(% zveY3tB@T6xtZkBdnM`elymT4)BtKiEwo1Ak=~kmGl9g>nJuZ~773p@VPa;XV#Z-R> z(p#m@+Z?2~Q|^FfK$5W2cqVzbU2@1>5=ZZ{P;YOpb>_Y^6x)5L-)<=p_oW!BWMT9F|C1VQNYuE`xHCOnL)^1ZfjGw-Sk`MI|T_i~)p5 zZwMzx#EDsPT3G}`Ex#rThLeoXsiLIMgR5ci*Kqs><9ozYB3HvoI0BLz*pb8{bPH%W z;rT1kxkZRbCN(7>Bohgy?&z{@5e!$7W6}92eOxJTKpuo1)HB!<9HbVV3&&J6iRmz= zs)$=d9z$C~iA-66rll(|(P~-5@u2}luV9og3Xr9&QoSb%#DJK93Kzd2h8L5GeUXH$ zM4-gbfa=U2T}o6MMp9Ro17PvtD6~3V7Kq}jBqGi#Nuy| zK$S_U3$a0F1PJP6Nx9CEDk1iy{5|w#G|o^4Cxb4f6*H$Kpiw2i<@9dtd7&d%CA8wIrOoC@L!p$P*vMcx`2*CYz6Frs);qbT 
z&UMb|IKjY@<12`fS3W#b@QRvOTo)X!6AuGD>kiz1=j58joTJdxuQm0ryC}!Y)o;yi z&An3Cx_dQL*m}GW7}Wx!dH3kIYi%9b=L>DSw6XHG5Z{{g^;0IoRY@ zXenBh88gS5iTIq1&5}`rp$D<86&0;sO(t*vl0yljxg}vf8ZQYDj%?X7eJ$uvdV!B zeWwe}XSC)s1^>9_A7AGjj;#zV@oP=pIr~TT_jyg+fB&p@VEpr!v;*f0P0wjf&t;-T zPt(dvcegKXFV=6#yN%zbnd<`QZU3siY30DZlZEVJ0p?(wXn;Oo+SUAY|(e7$*l z@AuhY#=53BWf-+8`w2V?P}@k-)GCdn#Geu7K>|!a|9%FKDO68wOU&FBO8oi%kG7yp z4y^Gh;}RqJRwf=L-o;msAztGHUr*lN^L-8&(-oiE$T{Ibuwx8Oq0vrO;k$MXjCPSof?HJE!Sl%Ya!}GF5^|PZYM$i9%hnng31nZPCV5X(A!@L!&6%cBZ zS(U>^k>6~Km?Z0saiQ$W7BJ1vGJJD=pV~Ie4UMdP!Ib$C zQ8|ywj2(m0g=d{s-Si|9s1Pw|$T6adMMQXrY^=6aph8BL{)A8$0iY~6IoO3pK2!_t zUd`PL{T?)VSl6J{^+U(UGY}fPveya?L9HPOC7-gqoWDKazW@GAq5Wi`?vz${D(^f6 zvKQTT%k|lj?7o6~S8@9eZTsPMhfwGIngjTTz&a;5oCd|27KVA2yt+BXpF{ghUjG(f4+A7N%Yfz{z|Bh8ynE_;JtvCmcesI67fhlX~+>h@fjiQ1z?;=@9b`E z+fczh^t9gA$#L4zWA`P=Z(I6X;aE@ZuLQG14L_>zs!uocSD8dBl?wpn4!%Tfr!Oj6 zj;@Ux9n=&5KuG%lOda%RJF|Yx-T$;624Kik$S?7S_c0Bd;lHGh(;NDzF3w`dVf0Z? z{NpBl3}oxG^_qL&DSeDoW*o}+0|lXS8sCB_dvMcv@HjPyO0Men%z#$W?^)`<5LQ_V z29}b`e^T&uYrgKhy<3-pmDvD9;!ccfKEQ)@4LwfYSp~KQi?2#WE0%SYVN2{nb;gyp zS#hyx?!iXaBw^D+)y|34Txq{K|4r%zXW9Vwrpdp@r;e4mDF$Z_+e~9$_`;bnCR4=$ zl9X(h2n(AE(-tJ8KF5QD5QC(=m`D&r61lyUy9=P?#a4~Xm?&DI1z%%%?+OuXg_TYP zgH6oOC*pA5L}VDo6#6RF%h>Ick^UPM`zAn!`>L_!y{T-=2i=9%pw=4v^|Pys+R(|u zo>SVMQ-#Jct#K^#{F=XMWu)Nm()?ZPoXye5z!G2d_?Nr0jhd$;yQp;!Vei(x|AA+J z-m|}0->~GSf3zKY_MxveZ*Sc=Q!3m?K}&xBO%-U(D1(o1nW+l-X{%Yo70gZL;?XK4 zF`Tp!$8f@!rw^yT`QMa~gE~tG-7cXGwR~MLQPUNRgo!RHwWP9--uhS?rn)NPa#I?A zfnwBC02%IKpy|uN@Pokc{hfut5iM{ebG{-6-P!i6TXXMt8ds^ze3V!*b#0>Wn53U| zZJ9M(vs|=dy~(&&Vi^8Ewnag0xXkzxZz<`q^_l5gQwQq{R;VqtpKQ=%EFnWBGXohU zrM1drzc>&blm|uTdIy&d!a^#T79~?~Htc(Su$z`AN2-C*$&|B(9;vSbU{U!@T=Eys z<`%VGhYP(&wB92Rd`F;q)(w9){U4``ZpgMR`K^N=_GYCI5`}Gp+O|O?HTTfh0)LJl zhN=qqKRN!yFzKtq`Vuk=0%CeJu^#my*UD_GD$2(|AgxQkZt~_O1J;0Zlxd2V^Xl{qKLgJ@he<<}=0%WueUVkIy(Oy;+?WOEg_w$>Ed zR{M3DtcgPE`VO!hGDm7L*)1t^0`8|8}tB|DXdsozUE2wH$8 z@NFxra);Fr>xN;rFDL&Q$m-7lpvO0Gfga7@U+@oT{(;Q7qQ7xDL>Bz>ybfXCV-8@o z;OT(4fX&Z!o-eus%f2t&+a9>LA*{f?U2|{Gv)`JpbtU{^N5QvK^X-Jz4&1c>*2!~) zz=Rf<$h#-LqU5o^mp&i<+qu7+D{MVm2%OUb=ko4zMPFm)#f|HP)(HJQ5&^91BwaI4 zn26tk;2Zd}RGo@&Zix#Eur?#)76XSEZZ@eViZ4!AE9x%>XSGRS}v-$o73AwQ0joQR>%l7v)!;G0c znLu<9bJ`IDMTkk+GsxU_2nmWN5bB^idgqypjF+8QdT0@wi36~^4^XMWptxg{&0q?b zsEKTVgYroY)9v_s5j#YdexqvL^za=&Xi_P^sZ3@7k7| zU>)L=C_aIO$Ts+lGK)v^R^>bIappdTiT*^2%?9|b&^7p&kH->lM5rh5Kn`ZY;6WFa z=0}Fhw;a@eIUf2|kQPK zOG93pE0(b4$#@b9mAs+Qd_r0z+jFTd6#CI(IA**-%cVpF%z=idWS^p7nSOc{z)hGd z0O2)>L`*qARU9U8l)y0pCkUJ(aFzgV(v^QjfQ`}=J-tfcH3Cc#{UJU55rH`Zj8%~} zQXv&vp*5nA=|zbXAp0saf9hSM5ujbW`Z*>{rt)vy<`7^zArNR24y|(lxh<=GA7L+0 zPMe;dX&3xiVBwaFuv;zI&kKFoeh^9M%T2ReEpUE67`=Key?QOZRErt>UbQ^Ej2&5l zO5RZmoOQw&#NP(k=@Xo}zI6__Rg4pEwP2f_@~of^$y_tbU2Us<$#5CQvpNRKwSKN+ z@PfFaQai<3;B*RY_dyffKAU2P#$Eoc*juI#~F+iG*JL+jW_b#zf z5xS}GV8|(p#_P^(W}r32fS6cka4;nxr1d*Kow5b{CWFqBO+iR4w3x`%QbPz4rDrFn zCN4~dUOs#N>}#c_(A30>7hgCVnwma)(a75pdU5PxXkv2w>}wN~=gn_iubzGR^o6Oj zP=9Ws7bZ!iFsm>D53yt7Oi*<^zU3`Nr|*kESHanpYc4qZzSwoN;5_=-nS%4=nzIZ0 z8&`L6$4H6(k4>1+t@IY+5m6ZNuHeB+yjpb* zp;#u#1PcLIC74x;-NHM`WEYa`W#=u(@?W1smN??wD9Dz@{U`_tH9|uxu~+38sQdjv zO;TqlmQ$6Ija5R6i7N3TQ)22*@bXHZQSg_R~>lqdKS%2U6|;q?#FB zOmqe{Owk;Ii5N88N-ag&$gHG<->Q;@X(-B1fh-bjB6y(X+4g%gh@JsQXX&4thH*8^1RW(0(S(9 z7PUHpXUPC~?8~V^x(m>GwS# z9j1Ph7MkITO>OT7e;&-9d3W&MU}hieyi4(=n@e%cvpt*4y?B4D(0B0um3-e3t@~)+ zb2M)|%0{E)qy1h;O)M%Ad2_;lgG#J~f6|En1p%)V&kwT`HpEqny<{_D|1G-`sU`^b z7C*<);9AXF*X-ybyCi$kTYkpc!=vONn5^40*JQL}xK4&^QUtD5#J|<31PALI=S5>sZ_$F)x+=xP&WI9 zBA7)gl~Tczq$W&X7{4$u&%B+f;0Rn6P&*I_3A@gG;yS4n65?0}MP|W#OAn()boiU% 
z>xsqbMMPBpfb_vhbcb1cs^Q_LrAyJ_hK|C>a1re?6Um~f=m*_Mt9XJ!QVt;@^JiJ@ zpbF_T79^s?Fa=X>tR@_8BRtqCnaC6emAiPU{0U*Wf&*_1ru%}7{Y|DCUkt}6^4u_y zj)-E=uUi&K^Jf-LC4;Aui_A08n5^8Q{GSl0NC#b5NQCB?Qxb~{B>{CKI`Tb2V>ZzM z8yyTW&ck4gOKmJoR0*jvm};1uZ_u{AFxe)-Uopfrk&_B-0vI?>e0CcEbfmqUvth;e z;dH^#qd9s=!nPNEZFzfJ(Shg|K^R)Olxw;dS`{9&?#Z|AK{y0lX|4MT-a*Yfn70oW zef~^;vAHWV_*KXD4@Qbk&lh!_1!regEjW9=*m=0%Jp5U2!8uyhxXH8NdS-3-S*CAy zlD^$JnzwCPM<_-+?e9j4jk~^V-1ney-)g+jII1;{W}Yv);PSIOP=4Jz^S9<#p8v9? z??Fpn?xjM@F0Ex(!8xEg2lCE=V*AeAwO_{nI$r2Mr1c-l_aDYjYd?~SXX0N`oX+mm zONI6$g}S3!-O;@B=(@jL>F)xZH?TB&r*j#n%XC*^$;$?rnVv1egc!Jhzei*qu?iE3 z|9`ZO+9oYUlY;UG84-1ZN)(?%3@32Y6%Lc|RThw&>LQ+}o6|S5AI4~I=(v+Nn2FJCV7M^ZDa+=R7Wylu~ z$I}fx0DR5}l2sa?BtwkAbh~7Viz|oMD#Kfc$2qP)X*?5}Q?1^M+DV&;-to9r#xW@pR&;k%*#jWX;VdK$IvPiLnN&F&#$jQaJ+uL>9DxRtW{x)f;1N3>Gbm0i zE-X;IJn{N!=@9jCcA66MflU86dQ*8#z8IRABp^uMG4EOVEeRC3Zm=H@(^<2<;-#M5 zJ(ua2)1@AEGG?r2uc$wrrYF3lPcD)?#i0Vmxv5aS26FisCusIEmBdsmD~)}Q@2z4~ ztsbwM6)%zp&^XGYmx?6hFnoI$R9T|xMcq$mT4N)64$d}}Pe8MQt&zOtk-SJlt!q*G z;i4yqVmhf*uDe1bxG@q`<8`PK>`WaFYqNu1C4RRqpGoVSOx2I$uo#YCz(GWP5mSDm zu!O^U~2J#=1&AofC%WU^ElfB_~e7 zqokl)`759*e@*x;Y}i6XX8?VYsJdLX)2osTkY-FNQ0^G>K&d5uK_#7>2AD2A)VD)N zsuX3kfmNdXJD?|sQ|<(SlVyPO){|Hckffb>xE1D!wl!bi&cWq_5WInAGF8Ig0FaT2 zF5il)JAH2TW_!pP6 zmy6DAUpjX@aPG+U<oYaPHNdd-LqK?&e(ncdp&K_Fl`1^zOEM+aTgKXJ?-M zAi15}n6*#VyE1&ozwBSazjYTrMFgg|Va0i;tGd;<2vK5_ zOg_Jihj;jQZ0`uoAJlawR z&%T|c^MYUzu^5i^>2u6jw6t{~n~TnuObf^&w`kwSB>i(B^m}OF9eWO08?Rb>BKA@i z_v_`K+_a@_)d;dXcJS(YbHGFGh9CPSesfuZa!3PVSFP+Y*P%3FfUv8q)>$KfPs+Eln)avPuP+(0FoVwq<$ z8QrTzA+xEzjl8I;d6i50h?e{>fc`CvANIO;UbywbO5=|&+`f>xfHOkNudQ=#2i7_O zMZI~KyRt{}o?iS|T*6^J`YP`}QdW@1?wu<*J2YoU-r2$MY0dQ&-H3`?X}s5!Yb&@1 z5b~L4zlT_5{=5G7r0k)0WA|cuEOg%8ONC+0|2TF#mchUE#&R2U^^aGm!Ejj;lT@g$ z%HI+otwIkOJay`nJ`hALg}`4@Kgsy^?YHRZL)ty8`5W%MviwSBVBPK&(8ni&Eu7%X z4CH+s*_R78QL~A8o5=nv2-YO!Q+e8a8$8?%B#;Wz_Ccivv(cNlR!cAexF^m77=e9d7)CO6Q-b#kJ z)FDG48`7G|kg5_<>V1|F7 z$+8_qVXEYtPwBth;>x+~EALKLGW5yl)b}EEuA>yt7kB8&4c+)GB2)SI1ZWBdoot%u zvvvX*rTU80M!@5BSyg^bxPK2&UeTgxOypI1L*O^mwm-)M+D897NyCck&gk+eWOBjN%fRYoO z4GJZ5|0_gG-3d>us#HIhNIJ{)|HUjBK&rAyh~ydRMwT^aksEnD+k>$=m)H)JnSdti65f~GBe-3sCqaLY>UmV=ub zsakO8EgKJq^yYBs(AVtEH|FscAdj5cYxh$f#~G7 z3p{=XG~8-IU?0Dw=xbgX&ocr$*t;dfAKk#^ZywF{-S;C{RMytT(I=G@4+3pqQs zDJfu{Df79lj8>Gt7Sot z&1L%H?Kj=SeUDJ4y%g&>n9DT50LHol<@fsq`iu9)yKaiZ1D{b_1?I(GTjcT{tp_Nx z8)feOj8bT}KWiL}GJWwLl)3K_%Cxu4J`P5i)^X@X`TOIJb3%K6tVR@s!Kf_F$WlU@ zO2(4o(^Jvuv1lSG&yFT%Whoww$C9yVED=pgQ;BIg(Gv&^MqiZKkLY+JI-8h{O-x*m zj>=LjiR>3}iVNdY7wLO+b^=#>^s^CEM8&kVM^RC1Dsbu=Rh>Y!JzA*=JUA7-AVo)~ zr>;mc6*@kZoQ|HGiOC7+Ja0x<^yS#ai&9+s8JL!NE=eztntB?Sr)MOYbu2-3CND`* zBTrW}A)zAH_7UT9&&+i^{4<*9p>|)+e_EbSNmJ*fWX^XsNu^E8IoEUJiDb_Ed~9Y0 zoz3}Pn4!UoP2_y%q_59PQ=?MOeSUUk0$*R6!ku<8S89yIG|Qd9Mb}I$d5LcFF(2UL zJTg5)H%msRCugP;<4I|RUmw+gkG>u1j|W$9^0puepdbQKk6~xraZwz0#hsG-vPaD& zj5*>itXr?-i@R?M*7Z8<$5l_>)j-^9E5U~nf^q*%VYnzB01O3l#V=1!Cx@ov>eQqv z&tr(M#3p7Xf+QMm85w$V?vccJO3Jw?9o3pnZKHD{Ko821ye-U$bHWJD_z@RTi6b@j z6q}<&G{#DcPN6RO+=%rR|rx0svw;-^j@L)Od1aBvoY{!yfGh;oL+R z$-J;tBt$CbPu_fH9)AyTNf^cA)ekyngIIzS45DyZzjaP{PaKkIaM7y@d?hGzRIWtI zI+iN4RZqjX7k3{J7c(zx2E(@xr(G3tHGUSZm4FPq44bT>^Y01r4qPzRlM9X~Mu2A+ zl1SCs>NDLtNY8nZuxc0fvO04fSpqhiveZqzWj&^rzAYq8qj6b3$2jPq_5)mfU%vye z5>x;6X^BsMl&vzhfPqKi(zWOqP)G6-Fr2Ne9Owwf^B$2UV7u`vQWV&V;GTeVcOnYB z1>D*NbQj0v*#x1@*u>S?^#p^p$?2)-`P1Rj=V7*I$1S%YGQ zTk|p-9eouHY9NpGXKm#8s06e~5Z><0Ig(>Jw;Y?gC^0C=o3iqYQgXz29HUW^TN*~LQ(PUgNrBcoXrFqlwd>J)$pirBoi#l#P#L1%{wxJMTV7AX={#p z{b_1}kF_8h+>s9MV08|)hiK7F#U`bZkz8P8WO6z_J3;3~BO_l2Aye<5^!W5BF8fj3 z#B?$-G9ovju$+g}8@Y;p_-16HEy?SA_!@CdZlR==64q>5Fnpb{u28ZY|A}QJ^ZH-b 
z(U5jDY=$cr#^>9&TpnlT(!kwATLMn(Nr~Vp+H#0=1Ifou{JeRPU+}26O6XRJ=~l^> z$M0-hKJo6^EdeL?q*8R&(IfVRjToDPQM6V+Xi({R@b{2diOxkK?)Y_(9WYqoxXVm9 z;%>l-LMU5LZENY*&kL6+(*bY6w?@c~^_$Mdo@z216iX z;dn8KSV`OoL{O3oKXqw*B5tB$O|-&wDh7rGJcuMuO7U?pG@Sgr=gfsh_~Pbk15jkb zWVtZEI+C0oVFeB4ip;rX0+?JeTUC0VQ5TJ^dJyJf7_JW|cEUL40GZUW*_;!X{@PuP_R2)9*woAW zaF(F?rXk~4g78YJ+6Id{H}E1oKzIn;K%G$AI)74eRc5Q3>ANDkqn*C1)9jOvVyPkq zOwZxuQOIQy`LyFQJ;{!H;tv33@9Ut(xU&l?0b*2-mK>;7?V?Q5XWkGH#HSka z8emzg3FJwb2B9(SYGiGYqqy@$bAYj0K=LCnf%Y~(gbPec6B8sCN3UKQAH8Hnf+}JJ z5|bvRNpOY0f~pupqN;Q)HVKKbD|(gC3$2;4>4}NytAq~?oJvgG;pkWUx}x|Dyv8u# zRVjLPdKUQoQVjg3!L;$}xp?$d6<6Y^<|vW3K2)N$L&ef`iPxlI(ta&>eiOUC}G!(pAXriC%ejYGgv1x|qB)azR1{39@m|=%_XW zKrr~wSS7&LF)sNzDK-{0Xt_4RdB}!~)P&Kk17>i^?>XfVpe|RM*9eZsZBh zU@62RUqi7Eh&==XY7>g;m0-_GQt5p%6C6~6gK5`b{<dkv-nRqF`o1KU)Ir)c64kOpGdP0`1OFW|L~T##M!YWAj#G@-hFk;iL=eR=$4x< zdxU7)isPPNX&f-}vCKSaVT`-Sd6<~}!#d+m7c_Riff)1#{T9wqDwhQ6=~lhzR=s{J zO1Gk>ThT3Fx$}g`MFBe767td>hjZ_ike~2rpivpczkYbsvj@vKDJ%|HwNUwj0oZ>J z7;>NFhZtaI=>xc0NGR|Z5(?r08>1gYxuG~L7Q@By5LlJsTrnXB12?2<2oV~5esqFx z2MO*ySo@5(@N(`wAv2mvR6OTl4V8B|pQ0ph=@6Y{t91$|q$yv|8MNo#Wp=zT#5a4yJ3*EOjh(+`qh0TfY>)`=-*+o2l(nYWqHR zibVsT2uQ%4xChus4Yh|KfwPl|9W(KCgaud!SlhQ}=bop0l-$98;wMPV?A*?@W9Me^ z;WYb@`ZeTiSsr*7>Q|hsUie^~on*nxIdOsgFeMEt7xs7zS~W3f)g}h5`q3CP44uPi zPdL(Kou#?V7nezM%RulZPORuNsRI?9Rv;1c!*W&6F()MT9LTK$p?MyXfK7qMZ zD-NP#Y;1%WI{9nJnxX86kjx7o`KvZOzMGf+tbF_G;*}pd?)lQ;{!F-ELFyS$JOitz z@1Fv+2A&cDu1IOJo)Y=eMhW^s$_hFjRAXGMEHHBtallE$DY-7|%h);&!yf3(UC^6* zp(O_sZ0NImdipmR-XoQfUF5c_bc; z7r~(i8uiO%T1lp7UC@)ehRa#5@I!Kut4~F|RH`go4lDvia2$ysSJn2lskzFKs~Rpp z)F! z6Z-7Da1*lC(HMur&~P-q1L@8<9E66O;`M;_hFrm~d&ZX#QljC^Ua6OK0@QaneO( zd(Xw>Cm7vmEiyGO%q5Hi9AQKU=8{oT>SKsT5*QpVUFG~n_Mka;u{HCkzPL~qb(H$~ z;5d@cV>g*|!HPmpvLh~>6PQLX zubsxt|AaDCWBf1oaWqy$yO*O3GHJ-PMCB{=s0h~s_9Vk|IX63J%Z%X=4U4>m zys6f_)mgBnkC8X=1SX83l%mqxhk=Q`!JEId)N;2+sqJ3#cBj4F8=_0P`485 zPJ7rV)SdN|tb3}~Jh*q)pYb#)o~AVWWc{JrjY}u*4k>kgYyQ5pzmFF%oeA|Qp`Nsd zeL_8^0+j0ZjHg5KbfnqmqtB|)oZnQ;-C)MkqIg=;o|ccAc4d7kWt28V+R|jwlUBz@s%E1B=#tY)f#up$|op@ip8EFO)7kvsAH3{cBzGv}Ob0 z#ad+c)+4i*^rl+k)*j>ZDvHsxc&NE>fnhU}bKVv%3O8VUU&eX<0q?auc$_0=OEDJ1 z{G6H^1vHY0SjEjyw?SnVizdlckW3L1a4UqT;#icsnrOxgvzHsgzvSkoaSU8Nk!Rna zqD&>J1}S(W;SA12QVixPM$}brPIzQ$MUwMnLT!nr!-J;uI=RS2FUr%iGu+n`c9(Iu zSYDkX6O~aGY%)`1>g}GEvG$^EWS-F$i>ggT56O^Ei<1?~a-vYm$}Bmm)mUF-(HE#J zDSmQ%l=TX}07by z1tnmKRSU%>^QW^e-@2<|&DF3pyF8Y0bt|s!w5vNCES-Oz6Ut&86^y`xQ?`JVP9XmY z%C#DE%mm30##ACypVUd2AXtcn&A?+NEjJuf6DU`QS4$n#7Ivt0G$%vam50}XQuNkX zkud7p`>@vCa6a^|D_Lr^0KP5Y%S<(49*rFfHylY_Y=JmxkX966u>LpPC_xvZKwjz@ zY4P0f&Up)yC%eIH&OPTL9q#Q03o&%exvf3d*so$6$>C(iH6>kxX9&vS^o7|Jh`zv> z22(G5QQ?D_$OTQX&go70$e27mX%H6%QH%#`_i<4FCKbvEOevFj&1&RakX>SkBasV{ zTOk8~Vz}Yqil6cgJkC7kR3fDEvfskx)UN!sQJ^&yv_d{Yeny0RMS0X`$2%vN8!~l$ zpeLD7zY^-7kI%;oQJK1T&M)_5>id=Y{?)z8o|Bo-DJ67@a(`6Yu=I^gEod5uPt`_w z^-@EoyiF-@n}2g7v?E>Dz4F%mlgiQOGoc|RG&CRIC~s8CJ2U0oN_qGE^U&g;ZXyuy zf_2$)K@kcv8nNi6|V-fdrZV zy%n*wwfDeYOVgNh5f)O7w~QY#8*L$;TIV#8+WA@2p4d%O&fRF>B)lGMSqrxOmGdhd znO*yoUHdY@{Yr3u+O?mf!?505F5p=Pj~I8t&kuUQ@_i0`xMxmWrswhwGk!5tFBgTR z&ej84=;DHP4O^!HLc(tXd+0Ub9Np%Jft6I;x%cz6CoWtRAFD;NKbd+ye=QgteS{xA zs>rE}cfiQC3?n-r9y*tGEUt|+ecKlc82udvqaU>6+iQ4)xdr3`oRuP#i%e=mj(=%F zIc}{j#8Yz^)y#_Dq0DoXFg!^NA>m2x?(nPlmFmo2ECpl&UTP;T8wV2T%=5VPtfyk# zQ@7@+d*|i5udNLHP`-C9Q-4sYKbY|xQap#!?30f*t&g2t+5e+~4-RJPk1F*?GoE9L z=UAG3HZ7$W&ZZZ&5h%Ldy;Q#BS#~MW{!Cy%2@DVey_XnhcnKictf^OOx}lGD?_kNT z(1N&dnDO$RYo5+*poSjUK@T88rW{fF`CI>X9$ePGiRIX~Kw za8zxP&ZQF2Xt;PD=Qdny!P)*j4ffPx zS|v7)4f;25Vnx3edS$W>Sdk1!N0I_w9zA=csF6FX&ee5G=ho}H*Xp~|J%gG06H5Jwbmdd# 
z%>0}IuMCb)kYkiwKmrN=B7M2!|AQ{wMUt!1lm|L^GH6L~^*hbS8`?^he+|D<^$+da zS?UMz2k6IUL+8q%($JH3?Z7@7y))*vW)BQuIpWsGthbGi!$@(f5V~Vdj z>kED-^zG2%qS?-72!ZO_DK*xPg9#pT7d5=_=pTE8AsC&AB@)=Mna@*v6~n@mk=oJ$ zh#p%?elK#T1}C=9L#C)Ee~Q#fQyq`qz+q~QY28o*mXpfvqxjC098-`6k14@pY1c8< zmZA1$oBf$Na_3yq)a+!=dv+>$3MbrJ(@8@}i~#rkp;~h8lVmrSiKmoFf|SX=lyjXU zHJ0J>Ez=z~1Edg^f0+`(vP>5u|2aDQE+xc>bByvA>5TXyDC218C#B@2={O@O@^8|; z-=gGqC?TB(AJeA{Nv0rpKU@NK4~#o;`# zQoeHG-uOM}=8?HDkVZg8yrzE8vY_KZ7651W z4T(E1l79%c;l|x@4_t4&aBJ~k8!k5C!Z`*GL4=qHwen-w2k&{mp+85yVBF+xgc8V` zzKE3&LJ38AB^29BfJ2f?+iZ)nbpvjJvco956fuM>x!CjMSQLrZ$4l{Dh9}D{c`PnZ zXcajaRUlWTxi1C>j~wJ&gh)nGwc%Q19SO)3<+xj`jf}lcX@?Q7$n0>1zE)rxGh!A! zv}|i1?KPRp_P~Y7J>2jMmPan0?7%*YDwG%fjPmR)vWz$SZj@(^E8~h+qwJk_2#}hj zX4B}qSX`wRjA%1vq82l;i@{SXe(lJEr`_>7aI+1}%8_awG{tQp1&m>BjH*#QD8>kx z0^zHq$qU>d;?sV9+G(<~^q7n+U!^E4Jw!peG!y7+7KawBFwUP71tLk*9Hhj$7hBzZ z?E_VO3dS53N++sYR@H4BC#PZhOtM%x(K!4bN0SrRyI?hiEhZM{mnCchlw>WmiR#w~ zqbdc^;ktT=Yol;lh)pF~zoCxO_CpuSEh7xe2!c@PgTqx^X3H{--eG21FpsV&ii>J+ z)y9+da!kf%qRb*8#V@kZ7ufu43Z=uI7=otjcQJ7C-$0_-X3171|4k$o!z`oVsS?e! zt?kYr0b30z0rP8wMMB9|)Wc_LXdZtYLq<`UMH0tDeAWo{Sk+p9EHTw`9@b{+GyxPT z@@R~-b8Qnh1X9*^8-L3{D^;X5luD7Qk;IHJ3s_N!5j7=>BWdEUY&pF|2fC6t_0iDu68#5mZ5hcx2a=hffo8ey=zIzryf>w0h^$% z^D^QHjgRr!VKc>ym@I}$^f~zU8i3jZz`BZ}_-rw!9ov-Uyx8YLme5pK>!Nz$E$Se# za@dXuQ(xs`{LbaS(vu1HE5ZJ>tDkp|aDaA@je~XXK2EHV^=R`3&86d)@qbP*`}fBk z%T#LKVOFWyA;5A^uxbOWLu$SEO%Vh3w^7wls^&C7C&Nf%)3a0Y!`Li9E*x*&fZHR% zj{J9!1Fzz$wwwu~h;l z41j<0D$|QFd%f68vE2Jy@XzMPDgPr>GnA@)ejI`5*iQT8CDQ_M$^RJl<-bqKSxU@6 zuTP8>=&P(#VnBXl;knT>e3vHa8PviiX+0EO3q_YpGofZB)Xac?{u>)bmFcS9^{V}A zRr`OGxZnTVS3bD%dx;PC{_gc3U(ZyXRjSTrik?x5o=JP2*~Zv|IymmyhAQqJG>AGQ zBTU_X50o2r33mT1`IX(1EEHxvl$D47zdA zHqmp!xCo7=$v#yWky=@OEbKQ-W@lPgd586h8?HImBbGeCN+#b`&zxwce?~9n5T-9K z-m`fOE|X>>?|`Eb#=|fU97T@Xdv3!K0F=_1GEW-e6W|*76;4c8=Pc248g_(o(9m%N zc9gIqAD%$!u6~cIM|lxXbB4hUhmpxdE^LM7M9u^6K!oYh{9EVIlAQQvPP~$Vj18F7h62$xla$v(wapd;+lH=Co?a8I{=};5aLrM9a z;`Nf2wUU-hNvl%Qx+S>X&q7?M^p*>j=>BX|^Lo>swWd9prhQ7&zO<(?Ti(51zGtm` zPo{jIQoaw!&r6R`daHGzY2gB#^3RITiVL1Uo4kE}Y3EYe>b`XNP$qmx2_ISyKe-lu z^210bd{POY%=k{FeW$j(tOTD>6j43@3Nu1O=!F%W&5TfCyk`YID*!MCsZFzV6CE)Y zt8SUrRvwt6oF8mD;#el6B%fZ_o6_|}p4v8OaE^&onKmt^i@!lrLJBZUNw9qJ{GHd= zE85m7+A+&svp{pW$+nNoFjte_n0d;phuI949MJ3(`rXz*?PJf;!@wU}i#= z`nt0}N5uL)QQky%vHZ$LXhHEjYFy6*OdE_Zrz{q}obgZAyjK3ndXMnK15$#5k*yVq z8kdqQUrq<%inUJ(?n}G&add0uF$vkQ$!Nui9X`prj_JUC!{d^sBPO99%t$n59V0#? 
zoS{+0OF}`wGVpnZXM`SV7G4+NFPX??p3(x7G{Xf_8e4|aj17WR!W+pqH2>r;Xf`c# zM}E-B820S|7`XSO24EycFwwD#9Gzj6y+V1=Ayq^+d4o($^E-T45V>|7T786rj52Om zjdaFtl2}7(0mg*$e#!p^(*mH0x^>0uP=f~JnFRS3Nf7=SSDHxz^bilP7Exgm5~1} z4L1#jCgU+GV9jKu0@izfgz=XD9iHcl);z2pTxICK8hN%4L|_)m%v`mV1+nC^Z$Ln; z@9Q;RB*zz%kJvIzZ{A^Lcfg$=Vm!2yaovTvV!QC|-fn*z9EF5A_rskh;2w6E3x%nF zaYmYek&zKJHPiveMck`_&=-m6L=F(`oWkldJY&|T8axvxVFr7~N@c+jnMw6qNJgb=WMYq5~}2@GPPVlXBCzzDFf0q+zje zsWiCEXXzkLRF9=9gcB2%@jL&p%i!=+s%HGCD@sZukx9znX_oM4PP}G@W%<`}E!Ck> zA_@uy$~kh04Y$eGt$Q8)4?-R-cxYZommlJvm7R;Py#3m*zP8+-so14d?9y*AfQMkB zOX^Hb%jHys8dGeoFz3OB4Cv6|L8t3Qxu#_ux#cW9;)i;ER!$(YbAnb6F>{&|%+K-L zP`Fr&{x8(-UqJ?FgCjz?@;g)Co_gopa{0S8-+%q?>pwiVTKnF}y^&1uF{SueI&{3C zIOA}`1c^iyW+G+lk*2js6S^JQr9^hkk7qqa-E(?6Zv; zfTo;b0e$)><#DC*tX@85kWny zDqOk2T)(mX0KO@O>!TjA1NHTM&S(9`vz8}GgTs!?LG275Me8h6;*_vG%VePCTO7vd>_tk3DIvX_Oh`~uE3=%F zbaeuW<|?Z0YWN?xl#1kc%}w1Pkpn79*x{bdQu2|v6pk$I@4TMzb}3#sF8kb_8<8d@ z(xpWD=f~&AH~eK31OA;~%J_Q~e=j^y+@0C->N{tZ@^s9J{*X#DJ)$PmF9Z>2Hz%`2AgA4S{odQXu9ogO%&pRS>Im={!1J| z>*n>ptfMpS=tR)^rC_?ag$GV*Si*pYoDJEU(-bWUL6NKpkCx{DD|_!@!Qte7 z4g0buC0=I)ajFpQ2`4LOR$o=R;G=_cdlGTOqbT|AbwCcho?c~N_N3V1tfdB-6Q3Uy zT%{*h+ta&`D2+!c5I>5yCY26n*V3Wo^Y6Z@?7%9)FMCoXcI`-a9{m?y+RNTWd)d3{ zx7;D;ej4rlo8hPuZsWgwf-k&4I?SfKz3Fhf;_cY-Yx!7ap0pGy&>L|cBzLcan}`BX zL*8Epp%*RLNrkg}0l^oz`Ab_tEg#Fwlb(VFmQjI7)B@?JULgI@3e@tk%seSr;4jkU zhLARiI9dJ5{m~Czqp`qjSd#{Vt%MD>_)iXd5$EA_E#MO0Mxq0s1o=W*Av ziJ8!HV=uaSSGdgpWbKvp*EdD;9`QK~J79!-oe2&5tSf-&*L zDv^Qj?zW&MHl3(+!}VMy~D7IlmZ?e+6aBe;P^71Gh1}OpP&s z@*namvts}pn`t|@t59F^oW~$zy8zX>VvEHu0b!Zk%m}<>eZa`CK~5x@wNdsHFG0(j ziP4%M?Zs7@k~$iZ08XHYZV<42zpNuu)}@qn&G%%BBlGPW#pR1%di$%t`ql5gcIUNB zakEm~Oun|*f{Nsqn~6%U0|eGruI-dTI>uD7jtyj zmC2Wq11_RQ3Gu%Ic-RKJ`0Onf<4g{iid%Pbql+Ttg$ntq{FW=EydvRZAMDw3}5#%NI1_o)6lMd*^NIwI$A(Wi3 zb8$k1H}9}$KIB($WqKT;daw%vVSS|0H$j>3a);4ag1@?l8}=?qF`4N{ zbcRhtXc&3xi4I0BeNtYHW<;N9-!mYIQIGtwd?RRPU3|ST-A!IVMtON zm|)nGzwwfHJf^;vh-zoHB9z6-rVB*vaxM%?Jm;AlkH_(5NEb|t6JJmOc`aBW-b&?$ zG8MIWN7QQEBUI0O>I;Y>v?9tVDMw-x`pI&fsd~tm9SGCc7^YwKf&8Hi&xnDjFl#p)|hor~>VH zq{_(pkyjawN~+vi86U&fs5-(Pk69hPXf>}o8h?A=Xr^M8zTu^H1(gvrlw6!=TA%0X zCrzLF!V9WUpJlsXUrKi=Ig6?ie;-fH>wjAguhTVIdP5sFeTU%!(UFfI0WiP_q9Zqi0Fd$=8EdYOH03`R6@!{5M(1U<*nB9RzAE) zy&?{O)U#g@8GBS@Om9ow4jHp|ral5P4zMyj1(9)3dyDGGQxuG82#A8Paq4?OU6v~< zE6uzhSoAh>Csh#@bMj$0k7Md2#02uVU~$x-c&E^ECVvD(r!ZrFLQG$lfRPEwK~^BV z&zue7(pBmD2#YKPazNQ+EAse)QI||&l3a_V5LT+Pj)#txCTX?opeL&d`lCobC+?B! z>FI))r_5FW?C3=GE3OsL6(+&J7(&vHQ`^xCVMaX~*w z-f`mp!6p(Ln{J7xZez}2Gsgx2Q%Q)v?m`jjVar`^X5wM7p&Hp?Wy4XTZrMT9aIvsL z$E=W<`&`NvSYlFL&#USOj+?;6>5P!(@6*P_d@>h9mb{8~a+3z5W)b>Hbcq>q!6<>h zgH^c=BaxudP+i~4M`C!3q$Y4`&MKYAF;757rVZ@n2fH7HWy0O1xS>rIomJUjW!hEA zua9Kh-HIFHqv-6~2-Yk?EHHc+H++$F+2DHFnYFSr8=k50_qyeO+ zKRAp;53B<>m0ONwT**HQyQwg47%r6_u9#HydpU7jrj@Je4|u>FZs6U`*ddaHnPLks z{l#WcBPxULsd9&qPBY5*W4Nl+4)V`R|5E?29sG+2mj}OpvrGq|I7 zvEk9XXso6=k=eByHM_P${--FD`NL>k_E7-rL=xM%x(Krpy@*C1r{{lxdj13+N6&!? 
zD~}mTwzi2P0@bWn^{rL)tyX8M4lDF|;=}56CFwHFnc1jV288q|Y|c;M#2!k?LK!XU zE-NBft(`N6Qi5X=qc=oq-ct?y%z!&!jDl(f|8-@z;`vmJTE>+mBKRwz0!^XDexdv#hSlG@bMO~EHU2x374+LqcbRkGfrHgL+1P??2hw};6xFhQoC)O%XP!KTR zw|y#*Hx1wf5XM_Wo$}{?6YX`(TXrrXHt?j-$%p6F3V}JP`L2TOJK&vOD`3 z0HM-SsQ{59{KOYiK@)DI~218e?)wD#HXS1w*(?p(Q^@gGwB zhtiG$udpVJN|^q^H6&J9#3nyrg85+%yJux3AY~Z29C7Ebn2g012?RA-7zbBsC!Dnz zomUm9%sxomfTjfLF>HfFa9LtXE4Davy%w~FjAR1{HIbmU0IeS3HOMEBZ-tqDw-{qSYv|j5g#S);xADncVtPdzn{%zNMUS64&%=CUH~T2@7E)@tZLb ziabI?1(#3_|9(~qCnZ( z>oSiDwoG>ybtmVOj$*Q3n3y~Q^?mU8v+ zv0Nn^4?EyXRd5`vBR=P$d9YrCCP1HGK$~7q;d!{>gOu;fn~l%g4mKO@?$p0jG51W2 z-V(!uxF8M&?5W&h%;G_+H(+>}k*AZ>-0A3?Bt)IX=JLC`h|-h>T~36&?e z+GiMz&v`C_>f1EkQ4}h z>~)qFkuXs-e-Xk2+W6{HaQSQLs)5zPOlZFn+K<<<6jv-J7e>;ax{pHP+pngpx>jQM z#uP+x4xLa!Cq8^O9eN?{DZmXAlF^w%J2z}4Ryy0I6FDFzfCWF`B5JZ+7^JuGM)V=D z#IF);y@;oCLDTv~MSYb%j5+sP~5qIxWNh7#p}B50E|=gyg$TTN_QVX%Rs^7&<0`(he|H z0Ikq4p*22`-FF$HWThrbiVP5dz?kgFuviuaTAd!x07z7&HE)#`WyEB((51x+pqRt~ zl);qNBw#t!L=_gHz4!Wl0uX^mQDRXuqFz;Vl=mH1HvfSHN3E0Ol~Nbg#tHxuc{4l>X%k^zx`e6h+^8SuQt1%<|# zAxtu!fCeGe4nwFf9B2p5s0ej3=R~j*K>Pay@$6sUcbnb5G$Uk@IO~VlJCdl}zsg z%9XNraGCQ%O+X||9;4yCNXaE6)`gy{&UOr7jVm-{V*CVzN z+n+mCQfrMULsJzsQz@ccJu(H<;Zj6^OAj=22Ft=F0VYw943P`*LgDpH?zcJ4vZ&WM zjV#ErkOmYGMV|YN-^hjH@D!6|H5yT_ggfe+ECMu5L}ob~KS>W#vYTLvwmtxB6bXUw z?K;{7*R~OC&iX>P1B$OETTy#wSgGjB7MI<=e0w@uTCJ3}Zs9!`ZBXVEwb9Ek5cu2W zZUan~9om4qeMQEzTk-5(A^(c90We5q1M}URDCTR+JMWBu+u!i-%zDbQWz9-iUlwn< zedhKH*>I&2Zb7kuR?-o*l8y*rz+BE&PR2}gDo|JG_k$BiY%J<$a#Q)f>fopT97Y0@ zjOnz^G5lTlUKhqr4SxO@NRLOtY;vv{%HY(e$4b&E9Yv>Mc6vm+S8*FBx1SLN%f>EQ zJO950yTkyMZCkYh@W02@%aj}gBrsb&_L;N_<-7Cvp%1+o{~5)9Cha)$5RQm+|K!ZW z97~`Ix2TeB$mp@TqW{bgiDi`k3vjBz@M6J2*-xgG0NKo4JK zFi!*)n&Mq&o){+7C@NRH%O)0zx9?#r)CGbi@=0L|ca4`}j8q;EwR#}qKce_ya0UGmCN}4I@q=ZRaL@Sw6jkC5)nZWc2 zB#OxbB_btE%3{1NBiY2qF%rqp44*vmhu||@iL+0AM<=J_vlG&Bnc`04BP9MM^e0dc zh~ieAPXrH6NjWHBT8QvZRtV0szpU`(^y7Xuh0e6~lNB1$)=yUGOu0NIrwHPopDe!eF|=K46Z8F-9fw71(8^D!bX@(!xRnx} z;#1=Cg?F)upUzgX3xU21q(wq;+2Zl#Csv2jWl(E9H6PgW3_3(?bts%J%6H!SnBDut zoa2luVb6IVI>M|Lf4*rntda-fr#J{!dD?bSv>tQ zUH!zkb-0Rh<|j1RxcZ54Yp+MFUA|0})ULkHzU)b)PCUG{_wJ!B0VgX}>}wUip%i9rSH|}`FDXd;1G{xy+sSJEe)kZ9g4Se%Sl)IjjVG&W9hy#O50g>$(~g2Lfy8)*`>aHs~0{P|1k3VHR=7&D1FavIqCA2ke9v) z_&Y7GOcK1E7C&_Sewd#XNFzeW3u3zcu+lyRv8%%aUbn-u<*gLy&5)F4Arw35WE1a) zy}spUcc4~oR6?z4zJLI=_sCYzc!)o$ruz$4;VvWOf`nCo@6AvH70Ir{rzNy-9!g=y z-EcazTXDB-`L+D)A)f3NI(xItyR)qwTV9Vy!eU-pxn1R8&CA5rY_NLEiEpFwR&uDJ z?w~@Rhh^l+K*7?y<*aT>S!d`(9W8c8A1(Gv8!h%j$V&@X;Vz}oLc%J*x7J{~UTCzm M{OloH!tnq90^Rqj%m4rY literal 0 HcmV?d00001 diff --git a/struct2tensor/benchmarks/ops_benchmark.py b/struct2tensor/benchmarks/ops_benchmark.py index dfe8da1..045ad3a 100644 --- a/struct2tensor/benchmarks/ops_benchmark.py +++ b/struct2tensor/benchmarks/ops_benchmark.py @@ -14,7 +14,6 @@ # pylint: disable=line-too-long r"""Benchmarks for struct2tensor. 
- Usage: blaze run -c opt --dynamic_mode=off \ --run_under='perflab \ @@ -31,10 +30,11 @@ """ # pylint: disable=line-too-long +import tensorflow as tf from absl.testing import parameterized + from struct2tensor.benchmarks import struct2tensor_benchmark_util from struct2tensor.ops import struct2tensor_ops -import tensorflow as tf class EquiJoinIndicesBenchmarks(struct2tensor_benchmark_util.OpsBenchmarks): diff --git a/struct2tensor/benchmarks/struct2tensor_benchmark.py b/struct2tensor/benchmarks/struct2tensor_benchmark.py index 7a221a0..639434a 100644 --- a/struct2tensor/benchmarks/struct2tensor_benchmark.py +++ b/struct2tensor/benchmarks/struct2tensor_benchmark.py @@ -16,11 +16,11 @@ """ +import tensorflow as tf from absl.testing import parameterized + import struct2tensor as s2t -from struct2tensor.benchmarks import benchmark_pb2 -from struct2tensor.benchmarks import struct2tensor_benchmark_util -import tensorflow as tf +from struct2tensor.benchmarks import benchmark_pb2, struct2tensor_benchmark_util class ProjectBenchmarks(struct2tensor_benchmark_util.ProtoDataBenchmarks): @@ -516,12 +516,12 @@ class PrensorToTensorBenchmarks( # pylint: disable=g-complex-comprehension @parameterized.named_parameters(*[ dict( - testcase_name="flat_to_dense_{}_features".format(n), - fn_name="flat_to_dense_{}_features".format(n), + testcase_name=f"flat_to_dense_{n}_features", + fn_name=f"flat_to_dense_{n}_features", fn_args=[ benchmark_pb2.FlatProto.DESCRIPTOR, [ - s2t.path.Path(["int_values_{}".format(i + 1)]) + s2t.path.Path([f"int_values_{i + 1}"]) for i in range(n) ] ], @@ -544,7 +544,7 @@ class PrensorToTensorBenchmarks( fn_args=[ benchmark_pb2.FlatProto100.DESCRIPTOR, [ - s2t.path.Path(["int_values_{}".format(i)]) + s2t.path.Path([f"int_values_{i}"]) for i in range(1, 101) ] ], @@ -555,7 +555,7 @@ class PrensorToTensorBenchmarks( fn_args=[ benchmark_pb2.FlatProto100.DESCRIPTOR, [ - s2t.path.Path(["int_values_{}".format(i)]) + s2t.path.Path([f"int_values_{i}"]) for i in range(1, 101) ] ], @@ -617,12 +617,12 @@ def test_to_dense(self, fn_name, fn_args, proto_list_key): # pylint: disable=g-complex-comprehension @parameterized.named_parameters(*[ dict( - testcase_name="flat_to_ragged_{}_features".format(n), - fn_name="flat_to_ragged_{}_features".format(n), + testcase_name=f"flat_to_ragged_{n}_features", + fn_name=f"flat_to_ragged_{n}_features", fn_args=[ benchmark_pb2.FlatProto.DESCRIPTOR, [ - s2t.path.Path(["int_values_{}".format(i + 1)]) + s2t.path.Path([f"int_values_{i + 1}"]) for i in range(n) ] ], @@ -645,7 +645,7 @@ def test_to_dense(self, fn_name, fn_args, proto_list_key): fn_args=[ benchmark_pb2.FlatProto100.DESCRIPTOR, [ - s2t.path.Path(["int_values_{}".format(i)]) + s2t.path.Path([f"int_values_{i}"]) for i in range(1, 101) ] ], @@ -656,7 +656,7 @@ def test_to_dense(self, fn_name, fn_args, proto_list_key): fn_args=[ benchmark_pb2.FlatProto100.DESCRIPTOR, [ - s2t.path.Path(["int_values_{}".format(i)]) + s2t.path.Path([f"int_values_{i}"]) for i in range(1, 101) ] ], @@ -718,12 +718,12 @@ def test_to_ragged(self, fn_name, fn_args, proto_list_key): # pylint: disable=g-complex-comprehension @parameterized.named_parameters(*[ dict( - testcase_name="flat_to_sparse_{}_features".format(n), - fn_name="flat_to_sparse_{}_features".format(n), + testcase_name=f"flat_to_sparse_{n}_features", + fn_name=f"flat_to_sparse_{n}_features", fn_args=[ benchmark_pb2.FlatProto.DESCRIPTOR, [ - s2t.path.Path(["int_values_{}".format(i + 1)]) + s2t.path.Path([f"int_values_{i + 1}"]) for i in range(n) ] ], @@ -746,7 +746,7 @@ 
def test_to_ragged(self, fn_name, fn_args, proto_list_key): fn_args=[ benchmark_pb2.FlatProto100.DESCRIPTOR, [ - s2t.path.Path(["int_values_{}".format(i)]) + s2t.path.Path([f"int_values_{i}"]) for i in range(1, 101) ] ], @@ -757,7 +757,7 @@ def test_to_ragged(self, fn_name, fn_args, proto_list_key): fn_args=[ benchmark_pb2.FlatProto100.DESCRIPTOR, [ - s2t.path.Path(["int_values_{}".format(i)]) + s2t.path.Path([f"int_values_{i}"]) for i in range(1, 101) ] ], @@ -823,8 +823,8 @@ class TfExampleBenchmarks(struct2tensor_benchmark_util.ProtoDataBenchmarks): # pylint: disable=g-complex-comprehension @parameterized.named_parameters(*[ dict( - testcase_name="to_fixed_len_feature_{}_features_1_value".format(n), - fn_name="tf_example_to_fixed_len_feature_{}".format(n), + testcase_name=f"to_fixed_len_feature_{n}_features_1_value", + fn_name=f"tf_example_to_fixed_len_feature_{n}", fn_args=[{ "int_values_1": tf.io.FixedLenFeature(shape=[1], dtype=tf.int64) }], @@ -842,7 +842,7 @@ class TfExampleBenchmarks(struct2tensor_benchmark_util.ProtoDataBenchmarks): testcase_name="to_fixed_len_feature_100_features_1_value", fn_name="tf_example_to_fixed_len_feature_100_features_1_value", fn_args=[{ - "int_values_{}".format(i): tf.io.FixedLenFeature( + f"int_values_{i}": tf.io.FixedLenFeature( shape=[1], dtype=tf.int64) for i in range(1, 101) }], proto_list_key="tf_examples_100_features"), @@ -850,7 +850,7 @@ class TfExampleBenchmarks(struct2tensor_benchmark_util.ProtoDataBenchmarks): testcase_name="to_fixed_len_feature_100_features_100_values", fn_name="tf_example_to_fixed_len_feature_100_features_100_values", fn_args=[{ - "int_values_{}".format(i): tf.io.FixedLenFeature( + f"int_values_{i}": tf.io.FixedLenFeature( shape=[100], dtype=tf.int64) for i in range(1, 101) }], proto_list_key="tf_examples_100_features_100_values") @@ -862,8 +862,8 @@ def test_parse_tf_example(self, fn_name, fn_args, proto_list_key): @parameterized.named_parameters(*[ dict( - testcase_name="to_var_len_feature_{}_features_1_value".format(n), - fn_name="tf_example_to_var_len_feature_{}".format(n), + testcase_name=f"to_var_len_feature_{n}_features_1_value", + fn_name=f"tf_example_to_var_len_feature_{n}", fn_args=[{ "int_values_1": tf.io.VarLenFeature(dtype=tf.int64) }], @@ -880,7 +880,7 @@ def test_parse_tf_example(self, fn_name, fn_args, proto_list_key): testcase_name="to_var_len_feature_100_features_1_value", fn_name="tf_example_to_var_len_feature_100_features_1_value", fn_args=[{ - "int_values_{}".format(i): tf.io.VarLenFeature(dtype=tf.int64) + f"int_values_{i}": tf.io.VarLenFeature(dtype=tf.int64) for i in range(1, 101) }], proto_list_key="tf_examples_100_features"), @@ -888,7 +888,7 @@ def test_parse_tf_example(self, fn_name, fn_args, proto_list_key): testcase_name="to_var_len_feature_100_features_100_values", fn_name="tf_example_to_var_len_feature_100_features_100_values", fn_args=[{ - "int_values_{}".format(i): tf.io.VarLenFeature(dtype=tf.int64) + f"int_values_{i}": tf.io.VarLenFeature(dtype=tf.int64) for i in range(1, 101) }], proto_list_key="tf_examples_100_features_100_values") @@ -899,8 +899,8 @@ def test_parse_tf_example_to_sparse(self, fn_name, fn_args, proto_list_key): @parameterized.named_parameters(*[ dict( - testcase_name="to_ragged_feature_{}_features_1_value".format(n), - fn_name="tf_example_to_ragged_feature_{}".format(n), + testcase_name=f"to_ragged_feature_{n}_features_1_value", + fn_name=f"tf_example_to_ragged_feature_{n}", fn_args=[{ "int_values_1": tf.io.RaggedFeature(value_key="int_values_1", dtype=tf.int64) @@ 
-919,8 +919,8 @@ def test_parse_tf_example_to_sparse(self, fn_name, fn_args, proto_list_key): testcase_name="to_ragged_feature_100_features_1_value", fn_name="tf_example_to_ragged_feature_100_features_1_value", fn_args=[{ - "int_values_{}".format(i): tf.io.RaggedFeature( - value_key="int_values_{}".format(i), dtype=tf.int64) + f"int_values_{i}": tf.io.RaggedFeature( + value_key=f"int_values_{i}", dtype=tf.int64) for i in range(1, 101) }], proto_list_key="tf_examples_100_features"), @@ -928,8 +928,8 @@ def test_parse_tf_example_to_sparse(self, fn_name, fn_args, proto_list_key): testcase_name="to_ragged_feature_100_features_100_values", fn_name="tf_example_to_ragged_feature_100_features_100_values", fn_args=[{ - "int_values_{}".format(i): tf.io.RaggedFeature( - value_key="int_values_{}".format(i), dtype=tf.int64) + f"int_values_{i}": tf.io.RaggedFeature( + value_key=f"int_values_{i}", dtype=tf.int64) for i in range(1, 101) }], proto_list_key="tf_examples_100_features_100_values") diff --git a/struct2tensor/benchmarks/struct2tensor_benchmark_util.py b/struct2tensor/benchmarks/struct2tensor_benchmark_util.py index 067dd99..1f57422 100644 --- a/struct2tensor/benchmarks/struct2tensor_benchmark_util.py +++ b/struct2tensor/benchmarks/struct2tensor_benchmark_util.py @@ -18,14 +18,14 @@ import statistics import timeit -from absl import flags -from absl.testing import parameterized import cpuinfo import psutil import tensorflow as tf - -from tensorflow.core.protobuf import rewriter_config_pb2 # pylint: disable=g-direct-tensorflow-import - +from absl import flags +from absl.testing import parameterized +from tensorflow.core.protobuf import ( + rewriter_config_pb2, # pylint: disable=g-direct-tensorflow-import +) FLAGS = flags.FLAGS diff --git a/struct2tensor/calculate.py b/struct2tensor/calculate.py index d7a32e5..b5001eb 100644 --- a/struct2tensor/calculate.py +++ b/struct2tensor/calculate.py @@ -45,12 +45,10 @@ from typing import Dict, List, Mapping, Optional, Sequence, Tuple -from struct2tensor import calculate_options -from struct2tensor import expression -from struct2tensor import path -from struct2tensor import prensor import tensorflow as tf +from struct2tensor import calculate_options, expression, path, prensor + # type(id(...)), disambiguated for clarity IDExpression = int IDNodeTensor = int @@ -152,7 +150,6 @@ def calculate_prensors( Returns: a list of prensors. """ - return calculate_prensors_with_graph( expressions, options=options, feed_dict=feed_dict)[0] @@ -191,6 +188,7 @@ def _get_earliest_equal_calculation(expr: expression.Expression): """Finds an expression with an equal value. Recursively traverses sources while expressions are the identity. + Args: expr: the expression to find an equal value. @@ -214,7 +212,7 @@ def _node_type_str(node_tensor: prensor.NodeTensor): return _fancy_type_str(node_tensor.is_repeated, None) -class _ExpressionNode(object): +class _ExpressionNode: """A node representing an expression in the ExpressionGraph.""" def __init__(self, expr: expression.Expression): @@ -223,7 +221,6 @@ def __init__(self, expr: expression.Expression): Args: expr: must be the result of _get_earliest_equal_calculation(...) 
""" - self.expression = expr self.sources = [ _get_earliest_equal_calculation(x) @@ -250,12 +247,8 @@ def __eq__(self, node: "_ExpressionNode") -> bool: return all([a is b for a, b in zip(self.sources, node.sources)]) def __str__(self) -> str: - return ("expression: {expression} sources: {sources} destinations: " - "{destinations} value: {value}").format( - expression=str(self.expression), - sources=str(self.sources), - destinations=str(self.destinations), - value=str(self.value)) + return (f"expression: {str(self.expression)} sources: {str(self.sources)} destinations: " + f"{str(self.destinations)} value: {str(self.value)}") def _create_value_error(self) -> ValueError: """Creates a ValueError, assuming there should be one for this node.""" @@ -295,7 +288,7 @@ def __hash__(self) -> int: return hash(tuple([id(x) for x in self.sources])) -class ExpressionGraph(object): +class ExpressionGraph: """A graph representing the computation of a list of expressions.""" def __init__(self): diff --git a/struct2tensor/calculate_options.py b/struct2tensor/calculate_options.py index f73a1eb..e8b311e 100644 --- a/struct2tensor/calculate_options.py +++ b/struct2tensor/calculate_options.py @@ -19,7 +19,7 @@ """ -class Options(object): +class Options: """Options for calculate functions. Do not construct Options directly. The preferred method of creating an object diff --git a/struct2tensor/calculate_test.py b/struct2tensor/calculate_test.py index 8a6c8e5..9ced3eb 100644 --- a/struct2tensor/calculate_test.py +++ b/struct2tensor/calculate_test.py @@ -15,19 +15,21 @@ import random -from struct2tensor import calculate -from struct2tensor import calculate_options -from struct2tensor import create_expression -from struct2tensor import expression_add -from struct2tensor import path -from struct2tensor.expression_impl import promote -from struct2tensor.expression_impl import proto_test_util -from struct2tensor.test import expression_test_util -from struct2tensor.test import prensor_test_util import tensorflow as tf - from absl.testing import absltest -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) + +from struct2tensor import ( + calculate, + calculate_options, + create_expression, + expression_add, + path, +) +from struct2tensor.expression_impl import promote, proto_test_util +from struct2tensor.test import expression_test_util, prensor_test_util def get_mock_linear_graph(length): @@ -89,7 +91,7 @@ def test_calculate_mock(self): self.assertIs(my_input._calculate_output, result) def test_calculate_broken_mock_is_repeated(self): - with self.assertRaisesRegexp( + with self.assertRaisesRegex( ValueError, "Expression Node0 returned the wrong type: expected: repeated actual: optional ."): @@ -98,7 +100,7 @@ def test_calculate_broken_mock_is_repeated(self): calculate.calculate_values([single_node]) def test_calculate_broken_mock_dtype(self): - with self.assertRaisesRegexp( + with self.assertRaisesRegex( ValueError, "Expression Node0 returned the " "wrong type: expected: repeated actual: repeated ."): diff --git a/struct2tensor/calculate_with_source_paths.py b/struct2tensor/calculate_with_source_paths.py index 7fa8129..02bf714 100644 --- a/struct2tensor/calculate_with_source_paths.py +++ b/struct2tensor/calculate_with_source_paths.py @@ -13,28 +13,25 @@ # limitations under the License. 
"""Given an expression, calculate the prensor and source paths.""" -from typing import List, NamedTuple, Optional, Sequence, Set, Tuple +from typing import List, NamedTuple, Optional, Sequence, Tuple -from struct2tensor import calculate -from struct2tensor import calculate_options -from struct2tensor import expression -from struct2tensor import path -from struct2tensor import prensor +import tensorflow as tf +from google.protobuf import descriptor + +from struct2tensor import calculate, calculate_options, expression, path, prensor from struct2tensor.expression_impl import proto from struct2tensor.proto import query_metadata_pb2 -import tensorflow as tf -from google.protobuf import descriptor # A proto summary represents the tensor of protos and descriptor, along with # all the paths associated with it. # tensor is tf.Tensor # descriptor is descriptor.Descriptor # paths is List[path.Path] -ProtoRequirements = NamedTuple("ProtoRequirements", - [("tensor", tf.Tensor), - ("descriptor", descriptor.Descriptor), - ("paths", List[path.Path])]) +class ProtoRequirements(NamedTuple): + tensor: tf.Tensor + descriptor: descriptor.Descriptor + paths: List[path.Path] def calculate_prensors_with_source_paths( @@ -92,7 +89,6 @@ def requirements_to_metadata_proto( inp: Sequence[ProtoRequirements] result: a proto to be populated. """ - for x in inp: _requirement_to_parsed_proto_info(x, result.parsed_proto_info.add()) # pytype: disable=wrong-arg-types diff --git a/struct2tensor/calculate_with_source_paths_test.py b/struct2tensor/calculate_with_source_paths_test.py index fdeeb1e..09feaca 100644 --- a/struct2tensor/calculate_with_source_paths_test.py +++ b/struct2tensor/calculate_with_source_paths_test.py @@ -13,19 +13,17 @@ # limitations under the License. """Tests for calculate_with_source_paths.""" -from absl.testing import absltest -from struct2tensor import calculate_with_source_paths -from struct2tensor import path -from struct2tensor.expression_impl import promote -from struct2tensor.expression_impl import proto -from struct2tensor.expression_impl import proto_test_util -from struct2tensor.proto import query_metadata_pb2 -from struct2tensor.test import test_any_pb2 -from struct2tensor.test import test_pb2 import tensorflow as tf - +from absl.testing import absltest from google.protobuf import text_format -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) + +from struct2tensor import calculate_with_source_paths, path +from struct2tensor.expression_impl import promote, proto, proto_test_util +from struct2tensor.proto import query_metadata_pb2 +from struct2tensor.test import test_any_pb2, test_pb2 @test_util.run_all_in_graph_and_eager_modes @@ -38,8 +36,8 @@ def assertLen(self, arr, expected_len): # E.g.: ["a", "a", "b"] == ["a", "b"] def equal_ignore_order(self, a, b): debug_string = "[{}] vs [{}]".format( - ",".join(['"{}"'.format(str(ae)) for ae in a]), - ",".join(['"{}"'.format(str(be)) for be in b])) + ",".join([f'"{str(ae)}"' for ae in a]), + ",".join([f'"{str(be)}"' for be in b])) for ae in a: self.assertIn(ae, b, debug_string) for be in b: diff --git a/struct2tensor/create_expression.py b/struct2tensor/create_expression.py index 396955f..1c5436b 100644 --- a/struct2tensor/create_expression.py +++ b/struct2tensor/create_expression.py @@ -19,12 +19,10 @@ from typing import FrozenSet, Mapping, Optional, Sequence -from struct2tensor import calculate_options 
-from struct2tensor import expression -from struct2tensor import path -from struct2tensor import prensor import tensorflow as tf +from struct2tensor import calculate_options, expression, path, prensor + class _DirectExpression(expression.Expression): """An expression where the value is immediate. @@ -84,7 +82,7 @@ def known_field_names(self) -> FrozenSet[path.Step]: return frozenset(self._children.keys()) def __str__(self) -> str: - return "_DirectExpression: {}".format(str(id(self))) + return f"_DirectExpression: {str(id(self))}" def create_expression_from_prensor( diff --git a/struct2tensor/create_expression_test.py b/struct2tensor/create_expression_test.py index 3104bbc..9d08010 100644 --- a/struct2tensor/create_expression_test.py +++ b/struct2tensor/create_expression_test.py @@ -13,12 +13,11 @@ # limitations under the License. """Tests for struct2tensor.create_expression.""" -from absl.testing import absltest -from struct2tensor import create_expression -from struct2tensor import path -from struct2tensor.test import expression_test_util -from struct2tensor.test import prensor_test_util import tensorflow as tf +from absl.testing import absltest + +from struct2tensor import create_expression, path +from struct2tensor.test import expression_test_util, prensor_test_util class ExpressionTest(absltest.TestCase): diff --git a/struct2tensor/expression.py b/struct2tensor/expression.py index c3b7665..2d90790 100644 --- a/struct2tensor/expression.py +++ b/struct2tensor/expression.py @@ -32,14 +32,14 @@ import abc from typing import Callable, FrozenSet, List, Mapping, Optional, Sequence, Union -from struct2tensor import calculate_options -from struct2tensor import path -from struct2tensor import prensor import tensorflow as tf - -from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.util.lazy_loader import ( + LazyLoader, # pylint: disable=g-direct-tensorflow-import +) from tensorflow_metadata.proto.v0 import schema_pb2 +from struct2tensor import calculate_options, path, prensor + # The purpose of this type is to make it easy to write down paths as literals. # If we made it Text instead of str, then it wouldn't be easy anymore. 
CoercableToPath = path.CoercableToPath @@ -82,7 +82,7 @@ IndexValue = Union[int, tf.Tensor, tf.Variable] # pylint: disable=invalid-name -class Expression(object, metaclass=abc.ABCMeta): +class Expression(metaclass=abc.ABCMeta): """An expression represents the calculation of a prensor object.""" def __init__( @@ -118,7 +118,7 @@ def is_repeated(self) -> bool: @property def type(self) -> Optional[tf.DType]: - """dtype of the expression, or None if not a leaf expression.""" + """Dtype of the expression, or None if not a leaf expression.""" return self._type @property @@ -262,7 +262,7 @@ def get_child_or_error(self, field_name: path.Step) -> "Expression": """Gets a named child.""" result = self.get_child(field_name) if result is None: - raise KeyError("No such field: {}".format(field_name)) + raise KeyError(f"No such field: {field_name}") return result def get_descendant(self, p: path.Path) -> Optional["Expression"]: @@ -278,8 +278,7 @@ def get_descendant_or_error(self, p: path.Path) -> "Expression": """Finds the descendant at the path.""" result = self.get_descendant(p) if result is None: - raise ValueError("Missing path: {} in {}".format( - str(p), self.schema_string(limit=20))) + raise ValueError(f"Missing path: {str(p)} in {self.schema_string(limit=20)}") return result def get_known_children(self) -> Mapping[path.Step, "Expression"]: @@ -321,10 +320,10 @@ def _schema_string_helper(self, field_name: path.Step, """Helper for schema_string.""" repeated_as_string = "repeated" if self.is_repeated else "optional" if self.type is None: - result = ["{} {}:".format(repeated_as_string, str(field_name))] + result = [f"{repeated_as_string} {str(field_name)}:"] else: result = [ - "{} {} {}".format(repeated_as_string, str(self.type), str(field_name)) + f"{repeated_as_string} {str(self.type)} {str(field_name)}" ] if limit is not None and limit == 0: if self.get_known_children(): @@ -335,7 +334,7 @@ def _schema_string_helper(self, field_name: path.Step, for field_name, subexpression in self.get_known_children().items(): recursive = subexpression._schema_string_helper(field_name, new_limit) # pylint: disable=protected-access - result.extend([" {}".format(x) for x in recursive]) + result.extend([f" {x}" for x in recursive]) return result # Begin methods compatible with v1 API. ###################################### @@ -521,7 +520,6 @@ def create_proto_index(self, field_name: path.Step) -> "Expression": Returns: An Expression object representing the result of the operation. """ - return reroot.create_proto_index_field(self, field_name) def cogroup_by_index(self, source_path: CoercableToPath, left_name: path.Step, @@ -627,9 +625,10 @@ def __hash__(self) -> int: return id(self) def __eq__(self, expr: "Expression") -> bool: - """if hash(expr1) == hash(expr2): then expr1 == expr2. + """If hash(expr1) == hash(expr2): then expr1 == expr2. Do not override this method. 
+ Args: expr: The expression to check equality against diff --git a/struct2tensor/expression_add.py b/struct2tensor/expression_add.py index 251079b..fd09fbe 100644 --- a/struct2tensor/expression_add.py +++ b/struct2tensor/expression_add.py @@ -23,9 +23,7 @@ from typing import FrozenSet, Mapping, Optional, Sequence, Tuple -from struct2tensor import expression -from struct2tensor import path -from struct2tensor import prensor +from struct2tensor import expression, path, prensor from struct2tensor.calculate_options import Options @@ -106,8 +104,7 @@ def known_field_names(self) -> FrozenSet[path.Step]: def __str__(self) -> str: keys_to_add = ",".join([str(k) for k in self._path_map.keys()]) - return "_AddPathsExpression({}, [{}])".format( - str(self._origin), keys_to_add) + return f"_AddPathsExpression({str(self._origin)}, [{keys_to_add}])" def add_paths(root: expression.Expression, @@ -130,9 +127,9 @@ def add_paths(root: expression.Expression, """ for p in path_map.keys(): if root.get_descendant(p.get_parent()) is None: - raise ValueError("No parent of {}".format(p)) + raise ValueError(f"No parent of {p}") if root.get_descendant(p) is not None: - raise ValueError("Path already set: {}".format(str(p))) + raise ValueError(f"Path already set: {str(p)}") _, map_of_maps = create_subtrees(path_map) return _AddPathsExpression(root, map_of_maps) @@ -187,9 +184,9 @@ def add_to(root: expression.Expression, if not _is_true_source_expression( root.get_descendant_or_error(path_parent), origin_root.get_descendant_or_error(path_parent)): - raise ValueError("Not a true source for tree with {}".format(str(p))) + raise ValueError(f"Not a true source for tree with {str(p)}") if root.get_descendant(p) is not None: - raise ValueError("Already contains {}.".format(str(p))) + raise ValueError(f"Already contains {str(p)}.") path_map = { p: origin_root.get_descendant_or_error(p) for p, origin_root in origins.items() diff --git a/struct2tensor/expression_add_test.py b/struct2tensor/expression_add_test.py index 2743de3..c41d52d 100644 --- a/struct2tensor/expression_add_test.py +++ b/struct2tensor/expression_add_test.py @@ -13,14 +13,11 @@ # limitations under the License. 
"""Tests for struct2tensor.expression_add.""" +import tensorflow as tf from absl.testing import absltest -from struct2tensor import create_expression -from struct2tensor import expression_add -from struct2tensor import path -from struct2tensor.test import expression_test_util -from struct2tensor.test import prensor_test_util -import tensorflow as tf +from struct2tensor import create_expression, expression_add, path +from struct2tensor.test import expression_test_util, prensor_test_util class ModifyTest(absltest.TestCase): diff --git a/struct2tensor/expression_impl/__init__.py b/struct2tensor/expression_impl/__init__.py index acd29ae..f3c1cbc 100644 --- a/struct2tensor/expression_impl/__init__.py +++ b/struct2tensor/expression_impl/__init__.py @@ -24,20 +24,22 @@ ``` """ -from struct2tensor.expression_impl import apply_schema -from struct2tensor.expression_impl import broadcast -from struct2tensor.expression_impl import depth_limit -from struct2tensor.expression_impl import filter_expression -from struct2tensor.expression_impl import index -from struct2tensor.expression_impl import map_prensor -from struct2tensor.expression_impl import map_prensor_to_prensor -from struct2tensor.expression_impl import map_values -from struct2tensor.expression_impl import parquet -from struct2tensor.expression_impl import placeholder -from struct2tensor.expression_impl import project -from struct2tensor.expression_impl import promote -from struct2tensor.expression_impl import promote_and_broadcast -from struct2tensor.expression_impl import proto -from struct2tensor.expression_impl import reroot -from struct2tensor.expression_impl import size -from struct2tensor.expression_impl import slice_expression +from struct2tensor.expression_impl import ( + apply_schema, + broadcast, + depth_limit, + filter_expression, + index, + map_prensor, + map_prensor_to_prensor, + map_values, + parquet, + placeholder, + project, + promote, + promote_and_broadcast, + proto, + reroot, + size, + slice_expression, +) diff --git a/struct2tensor/expression_impl/__pycache__/__init__.cpython-311.pyc b/struct2tensor/expression_impl/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e13e04711aa005f11fedf257b3c28965899351d0 GIT binary patch literal 1314 zcmZvb%Wl*#6o#F-WhQfa z6$;(5VmDHEm5?|#MNOs?{qyTPImh;yljK{!?;*GjSAS*?T!em^;Qv*-g!3N&Ur~%O zik-N|owSB)j^R3R9oH>4fE&1Bxe46FP0KCd7H(N~fnDrcZUeV*+j0lEgFBWzU=Mqi zyTD!CwcG>l;hyC_a3A+A4}b@FV0j2U#6!y?;1M2K9s`f@*zyE;f+xUR@6q;+)1P~3 zCZzU5&iz!xOHP%~7Ct@BBvp!u1z~B%r=IsAq2&UdFy+q}r@l(WlE?l5^}{HFH8l1) zn^RxIf|=l49I?fxUC-O!-}hiIbC}X{8R;EO7fMJ!RKD8L-b{*gEpu7E){N&t`E>2( zG|V!7LR6H{G_+^^HqUR{IqQ`(5O)rw2;(SJ8iH<2Go295Qf9(2{3A^zx!N3VHrXPk z$EG+9GXirlT|;=Q6cQ~;8P8*x1C_7T9-m4vsE928l>%kj3dz%J3w3 zOVwP9EI1N!t}+-#Q0Z+gx_S=wH4F%03#JK~W+xZzm;c_^SMn|#)QcZZo7 znkGZJjG@3pAfRlKz(ivtd}v(-_Ct^-KjaVSOMzWt04@d)Dxko<38qOZi$n5 zS;+HCJiY^Yp%7RK6vQR55L^oK)F`pAll%};7?=4>6mkr5HPsF8GxfwKV9!uis(v|7N1grRk(hY4RV{B;YyiY!73S%{=OWB-~mnAim+0b+lNohmMD{0Jo zZTfj+SrHedAuu$BOBK;wUej&#e@7rrxZQ#_GTgf|I(Pt=`NTtL(kD+ZS~yvzp7x zSctT|#n?%g)>T--GK@r1X!*JKFa;AkYbctOJ|TRgSQC<3RHO{dYC~7BscJzr6zr3x zrsWKDI3X!UETe!~+SXm2jIb7|sBG35q%Ljd@_B45?UkZwiky_*$Vyx4hP0{W3`I)E zGB60WP&a0Rb&~`(Nk)|wtQ?yN!&J0Ft_bDS&0240YBokaP~S%0z-B#Er4_|CTSt$} z>CB@EtPfJ5CuO%@A`Lwx(~(s$Cp*pN@}w{_dS*^`t8Y&*CRW%=(S}I3q+~KlZ_YIA z-KnoCn&R|AJ2eVtY?5t?D~ZPDOtkS#vTr`)$K(gpDk0Ehl4hN$LBVx~638{jM7T z8@J{4i2@b;FS*#oM8JyP)6{LH_&_nN$lL4W@X~oJ@<93VhEmKZ3EpalYBL*oEIy^O zejSam+8lTjX%2jcp(XObL}1I7T3?>AL?X6Ac8e9#6rA#6MsU15eZjwT0sMXSdlVmY 
diff --git a/struct2tensor/expression_impl/apply_schema.py b/struct2tensor/expression_impl/apply_schema.py
--- a/struct2tensor/expression_impl/apply_schema.py
+++ b/struct2tensor/expression_impl/apply_schema.py
@@ -80,7 +77,6 @@ def _normalize_feature(feature: schema_pb2.Feature,
     feature: feature to modify in place.
     schema: schema containing any global domains.
""" - if feature.HasField("struct_domain"): for x in feature.struct_domain.feature: _normalize_feature(x, schema) @@ -97,8 +93,7 @@ def _normalize_feature(feature: schema_pb2.Feature, if float_domain.name == feature.domain: feature.float_domain.CopyFrom(float_domain) return - raise ValueError("Did not find domain {} in schema {}".format( - feature.domain, schema)) + raise ValueError(f"Did not find domain {feature.domain} in schema {schema}") def _clean_feature(feature: schema_pb2.Feature) -> schema_pb2.Feature: diff --git a/struct2tensor/expression_impl/apply_schema_test.py b/struct2tensor/expression_impl/apply_schema_test.py index 05ffee8..c40fe23 100644 --- a/struct2tensor/expression_impl/apply_schema_test.py +++ b/struct2tensor/expression_impl/apply_schema_test.py @@ -14,14 +14,13 @@ """Tests for struct2tensor.expression_impl.apply_schema.""" import copy + from absl.testing import absltest +from tensorflow_metadata.proto.v0 import schema_pb2 -from struct2tensor import create_expression -from struct2tensor import path +from struct2tensor import create_expression, path from struct2tensor.test import prensor_test_util -from tensorflow_metadata.proto.v0 import schema_pb2 - def _features_as_map(feature_list): return {feature.name: feature for feature in feature_list} diff --git a/struct2tensor/expression_impl/broadcast.py b/struct2tensor/expression_impl/broadcast.py index 2eb0873..ac63eca 100644 --- a/struct2tensor/expression_impl/broadcast.py +++ b/struct2tensor/expression_impl/broadcast.py @@ -82,14 +82,11 @@ from typing import FrozenSet, Optional, Sequence, Tuple -from struct2tensor import calculate_options -from struct2tensor import expression -from struct2tensor import expression_add -from struct2tensor import path -from struct2tensor import prensor -from struct2tensor.ops import struct2tensor_ops import tensorflow as tf +from struct2tensor import calculate_options, expression, expression_add, path, prensor +from struct2tensor.ops import struct2tensor_ops + class _BroadcastExpression(expression.Leaf): """A broadcast field.""" diff --git a/struct2tensor/expression_impl/broadcast_test.py b/struct2tensor/expression_impl/broadcast_test.py index 8aa61a6..081cf33 100644 --- a/struct2tensor/expression_impl/broadcast_test.py +++ b/struct2tensor/expression_impl/broadcast_test.py @@ -14,15 +14,15 @@ """Tests for struct2tensor.broadcast.""" -from absl.testing import absltest -from struct2tensor import create_expression -from struct2tensor import path -from struct2tensor.expression_impl import broadcast -from struct2tensor.test import expression_test_util -from struct2tensor.test import prensor_test_util import tensorflow as tf +from absl.testing import absltest +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import +from struct2tensor import create_expression, path +from struct2tensor.expression_impl import broadcast +from struct2tensor.test import expression_test_util, prensor_test_util class BroadcastTest(tf.test.TestCase): diff --git a/struct2tensor/expression_impl/depth_limit.py b/struct2tensor/expression_impl/depth_limit.py index cc5a2f2..34108c3 100644 --- a/struct2tensor/expression_impl/depth_limit.py +++ b/struct2tensor/expression_impl/depth_limit.py @@ -41,10 +41,7 @@ from typing import FrozenSet, Optional, Sequence -from struct2tensor import calculate_options -from struct2tensor import expression -from struct2tensor import path -from 
struct2tensor import prensor +from struct2tensor import calculate_options, expression, path, prensor def limit_depth(expr: expression.Expression, diff --git a/struct2tensor/expression_impl/depth_limit_test.py b/struct2tensor/expression_impl/depth_limit_test.py index 3841c7f..a7472f9 100644 --- a/struct2tensor/expression_impl/depth_limit_test.py +++ b/struct2tensor/expression_impl/depth_limit_test.py @@ -16,8 +16,7 @@ from absl.testing import absltest -from struct2tensor import create_expression -from struct2tensor import path +from struct2tensor import create_expression, path from struct2tensor.expression_impl import depth_limit from struct2tensor.test import prensor_test_util diff --git a/struct2tensor/expression_impl/filter_expression.py b/struct2tensor/expression_impl/filter_expression.py index a791f05..dca3fe0 100644 --- a/struct2tensor/expression_impl/filter_expression.py +++ b/struct2tensor/expression_impl/filter_expression.py @@ -62,21 +62,17 @@ from typing import FrozenSet, Optional, Sequence, Union -from struct2tensor import calculate_options -from struct2tensor import expression -from struct2tensor import expression_add -from struct2tensor import path -from struct2tensor import prensor -from struct2tensor.ops import struct2tensor_ops import tensorflow as tf +from struct2tensor import calculate_options, expression, expression_add, path, prensor +from struct2tensor.ops import struct2tensor_ops + def filter_by_sibling(expr: expression.Expression, p: path.Path, sibling_field_name: path.Step, new_field_name: path.Step) -> expression.Expression: """Filter an expression by its sibling. - This is similar to boolean_mask. The shape of the path being filtered and the sibling must be identical (e.g., each parent object must have an equal number of source and sibling children). diff --git a/struct2tensor/expression_impl/filter_expression_test.py b/struct2tensor/expression_impl/filter_expression_test.py index 360b292..620ecfc 100644 --- a/struct2tensor/expression_impl/filter_expression_test.py +++ b/struct2tensor/expression_impl/filter_expression_test.py @@ -13,29 +13,27 @@ # limitations under the License. """Tests for struct2tensor.promote.""" -from absl.testing import absltest -from struct2tensor import calculate -from struct2tensor import create_expression -from struct2tensor import path -from struct2tensor import prensor -# For tf.Session.Run against a Prensor -from struct2tensor import prensor_value # pylint: disable=unused-import -from struct2tensor.expression_impl import filter_expression -from struct2tensor.expression_impl import proto_test_util -from struct2tensor.test import expression_test_util -from struct2tensor.test import prensor_test_util -from struct2tensor.test import test_pb2 import tensorflow as tf +from absl.testing import absltest +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) - -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import +# For tf.Session.Run against a Prensor +from struct2tensor import ( + calculate, + create_expression, + path, + prensor, # pylint: disable=unused-import +) +from struct2tensor.expression_impl import filter_expression, proto_test_util +from struct2tensor.test import expression_test_util, prensor_test_util, test_pb2 def _create_slice_and_project_example(): r"""Creates an example for test_slice_and_project. 
Returns: - ------*----------------- / \ ---session0---- session1 diff --git a/struct2tensor/expression_impl/index.py b/struct2tensor/expression_impl/index.py index a18d457..8ce3329 100644 --- a/struct2tensor/expression_impl/index.py +++ b/struct2tensor/expression_impl/index.py @@ -45,8 +45,7 @@ get_positional_index(expr, path.Path(["event","val"]), "val_index") ``` -yields: - +Yields: ``` session: { event: { @@ -119,14 +118,11 @@ from typing import Optional, Sequence, Tuple -from struct2tensor import calculate_options -from struct2tensor import expression -from struct2tensor import expression_add -from struct2tensor import path -from struct2tensor import prensor -from struct2tensor.expression_impl import size import tensorflow as tf +from struct2tensor import calculate_options, expression, expression_add, path, prensor +from struct2tensor.expression_impl import size + def get_positional_index(expr: expression.Expression, source_path: path.Path, new_field_name: path.Step diff --git a/struct2tensor/expression_impl/index_test.py b/struct2tensor/expression_impl/index_test.py index e0df610..295a6ad 100644 --- a/struct2tensor/expression_impl/index_test.py +++ b/struct2tensor/expression_impl/index_test.py @@ -13,16 +13,15 @@ # limitations under the License. """Tests for struct2tensor.promote.""" -from struct2tensor import create_expression -from struct2tensor import path -from struct2tensor.expression_impl import index -from struct2tensor.test import expression_test_util -from struct2tensor.test import prensor_test_util - import tensorflow as tf - from absl.testing import absltest -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) + +from struct2tensor import create_expression, path +from struct2tensor.expression_impl import index +from struct2tensor.test import expression_test_util, prensor_test_util class IndexTest(absltest.TestCase): @@ -76,14 +75,12 @@ def test_get_index_from_end_calculate(self): prensor_test_util.create_nested_prensor()) new_root, new_path = index.get_index_from_end( expr, path.Path(["user", "friends"]), path.get_anonymous_field()) - print("test_get_index_from_end_calculate: new_path: {}".format(new_path)) + print(f"test_get_index_from_end_calculate: new_path: {new_path}") new_field = new_root.get_descendant_or_error(new_path) - print("test_get_index_from_end_calculate: new_field: {}".format( - str(new_field))) + print(f"test_get_index_from_end_calculate: new_field: {str(new_field)}") leaf_node = expression_test_util.calculate_value_slowly(new_field) - print("test_get_index_from_end_calculate: leaf_node: {}".format( - str(leaf_node))) + print(f"test_get_index_from_end_calculate: leaf_node: {str(leaf_node)}") self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 2, 3]) self.assertAllEqual(leaf_node.values, [-1, -2, -1, -1, -1]) diff --git a/struct2tensor/expression_impl/map_prensor.py b/struct2tensor/expression_impl/map_prensor.py index cec43b1..132fdd7 100644 --- a/struct2tensor/expression_impl/map_prensor.py +++ b/struct2tensor/expression_impl/map_prensor.py @@ -101,14 +101,11 @@ from typing import Callable, FrozenSet, Optional, Sequence, Tuple -from struct2tensor import calculate_options -from struct2tensor import expression -from struct2tensor import expression_add -from struct2tensor import path -from struct2tensor import prensor -from struct2tensor.expression_impl import project import tensorflow as tf +from struct2tensor 
import calculate_options, expression, expression_add, path, prensor +from struct2tensor.expression_impl import project + def map_sparse_tensor(root: expression.Expression, root_path: path.Path, paths: Sequence[path.Path], @@ -132,7 +129,6 @@ def map_sparse_tensor(root: expression.Expression, root_path: path.Path, A new root expression containing the old root expression plus the new path, root_path.get_child(new_field_name), with the result of the operation. """ - return _map_sparse_tensor_impl(root, root_path, paths, operation, is_repeated, dtype, new_field_name)[0] @@ -293,8 +289,7 @@ def new_op(pren: prensor.Prensor, result = _as_leaf_node(result_as_tensor, is_repeated, sparse_tensors[0].dense_shape[0], options) if result.values.dtype != dtype: - raise ValueError("Type unmatched: actual ({})!= expected ({})".format( - str(result.values.dtype), str(dtype))) + raise ValueError(f"Type unmatched: actual ({str(result.values.dtype)})!= expected ({str(dtype)})") return result return _map_prensor_impl(root, root_path, paths, new_op, is_repeated, dtype, @@ -366,8 +361,7 @@ def new_op(tree: prensor.Prensor, result = _ragged_as_leaf_node(result_as_tensor, is_repeated, ragged_tensors[0], options) if result.values.dtype != dtype: - raise ValueError("Type unmatched: actual ({})!= expected ({})".format( - str(result.values.dtype), str(dtype))) + raise ValueError(f"Type unmatched: actual ({str(result.values.dtype)})!= expected ({str(dtype)})") return result return _map_prensor_impl(root, root_path, paths, new_op, is_repeated, dtype, diff --git a/struct2tensor/expression_impl/map_prensor_test.py b/struct2tensor/expression_impl/map_prensor_test.py index 727407a..1b173b4 100644 --- a/struct2tensor/expression_impl/map_prensor_test.py +++ b/struct2tensor/expression_impl/map_prensor_test.py @@ -13,18 +13,15 @@ # limitations under the License. 
"""Tests for struct2tensor.map_prensor.""" -from absl.testing import absltest -from struct2tensor import calculate_options -from struct2tensor import create_expression -from struct2tensor import path -from struct2tensor import prensor -from struct2tensor.expression_impl import map_prensor -from struct2tensor.test import expression_test_util -from struct2tensor.test import prensor_test_util import tensorflow as tf +from absl.testing import absltest +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) - -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import +from struct2tensor import calculate_options, create_expression, path, prensor +from struct2tensor.expression_impl import map_prensor +from struct2tensor.test import expression_test_util, prensor_test_util def _create_one_value_prensor(): diff --git a/struct2tensor/expression_impl/map_prensor_to_prensor.py b/struct2tensor/expression_impl/map_prensor_to_prensor.py index ce3246b..0a66635 100644 --- a/struct2tensor/expression_impl/map_prensor_to_prensor.py +++ b/struct2tensor/expression_impl/map_prensor_to_prensor.py @@ -71,17 +71,13 @@ from typing import Any, Callable, Dict, FrozenSet, Optional, Sequence, Union -from struct2tensor import calculate_options -from struct2tensor import expression -from struct2tensor import expression_add -from struct2tensor import path -from struct2tensor import prensor import tensorflow as tf - from tensorflow_metadata.proto.v0 import schema_pb2 +from struct2tensor import calculate_options, expression, expression_add, path, prensor + -class Schema(object): +class Schema: """A finite schema for a prensor. Effectively, this stores everything for the prensor but the tensors @@ -138,15 +134,11 @@ def known_field_names(self) -> FrozenSet[path.Step]: return frozenset(self._children.keys()) def __str__(self) -> str: - return ("Schema(is_repeated:{is_repeated} type:{type}" + return (f"Schema(is_repeated:{self._is_repeated} type:{self._type}" " " - "schema_feature:({schema_feature})" + f"schema_feature:({self._schema_feature})" " " - "children:{children})").format( - is_repeated=self._is_repeated, - type=self._type, - schema_feature=self._schema_feature, - children=self._children) + f"children:{self._children})") def create_schema(is_repeated: bool = True, diff --git a/struct2tensor/expression_impl/map_prensor_to_prensor_test.py b/struct2tensor/expression_impl/map_prensor_to_prensor_test.py index 5edf25a..b7fd2c7 100644 --- a/struct2tensor/expression_impl/map_prensor_to_prensor_test.py +++ b/struct2tensor/expression_impl/map_prensor_to_prensor_test.py @@ -13,19 +13,22 @@ # limitations under the License. 
"""Tests for struct2tensor.expression_impl.map_prensor_to_prensor.""" +import tensorflow as tf from absl.testing import absltest -from struct2tensor import calculate -from struct2tensor import create_expression -from struct2tensor import path -from struct2tensor import prensor +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) +from tensorflow_metadata.proto.v0 import schema_pb2 + # For tf.Session.Run against a Prensor -from struct2tensor import prensor_value # pylint: disable=unused-import +from struct2tensor import ( + calculate, + create_expression, + path, + prensor, # pylint: disable=unused-import +) from struct2tensor.expression_impl import map_prensor_to_prensor from struct2tensor.test import prensor_test_util -import tensorflow as tf - -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import -from tensorflow_metadata.proto.v0 import schema_pb2 @test_util.run_all_in_graph_and_eager_modes diff --git a/struct2tensor/expression_impl/map_values.py b/struct2tensor/expression_impl/map_values.py index f02d0c8..a781ffb 100644 --- a/struct2tensor/expression_impl/map_values.py +++ b/struct2tensor/expression_impl/map_values.py @@ -23,13 +23,10 @@ from typing import Callable, FrozenSet, Optional, Sequence, Tuple -from struct2tensor import calculate_options -from struct2tensor import expression -from struct2tensor import expression_add -from struct2tensor import path -from struct2tensor import prensor import tensorflow as tf +from struct2tensor import calculate_options, expression, expression_add, path, prensor + def map_many_values( root: expression.Expression, parent_path: path.Path, @@ -114,7 +111,7 @@ def map_values(root: expression.Expression, source_path: path.Path, def _leaf_node_or_error(node: prensor.NodeTensor) -> prensor.LeafNodeTensor: if isinstance(node, prensor.LeafNodeTensor): return node - raise ValueError('node is {} not LeafNodeTensor'.format(str(type(node)))) + raise ValueError(f'node is {str(type(node))} not LeafNodeTensor') class _MapValuesExpression(expression.Expression): diff --git a/struct2tensor/expression_impl/map_values_test.py b/struct2tensor/expression_impl/map_values_test.py index cdd003f..69afd71 100644 --- a/struct2tensor/expression_impl/map_values_test.py +++ b/struct2tensor/expression_impl/map_values_test.py @@ -13,17 +13,15 @@ # limitations under the License. 
"""Tests for struct2tensor.broadcast.""" -from absl.testing import absltest -from struct2tensor import create_expression -from struct2tensor import path -from struct2tensor import prensor -from struct2tensor.expression_impl import map_values -from struct2tensor.test import expression_test_util -from struct2tensor.test import prensor_test_util import tensorflow as tf +from absl.testing import absltest +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) - -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import +from struct2tensor import create_expression, path, prensor +from struct2tensor.expression_impl import map_values +from struct2tensor.test import expression_test_util, prensor_test_util @test_util.run_all_in_graph_and_eager_modes diff --git a/struct2tensor/expression_impl/parquet.py b/struct2tensor/expression_impl/parquet.py index ff66324..d85f8ee 100644 --- a/struct2tensor/expression_impl/parquet.py +++ b/struct2tensor/expression_impl/parquet.py @@ -32,15 +32,12 @@ import pyarrow as pa import pyarrow.parquet as pq -from struct2tensor import calculate -from struct2tensor import calculate_options -from struct2tensor import expression -from struct2tensor import path -from struct2tensor import prensor +import tensorflow as tf + +from struct2tensor import calculate, calculate_options, expression, path, prensor from struct2tensor.expression_impl import map_prensor_to_prensor as mpp from struct2tensor.expression_impl import placeholder from struct2tensor.ops import gen_parquet_dataset -import tensorflow as tf def create_expression_from_parquet_file( @@ -54,7 +51,6 @@ def create_expression_from_parquet_file( A PlaceholderRootExpression that should be used as the root of an expression graph. """ - metadata = pq.ParquetFile(filenames[0]).metadata parquet_schema = metadata.schema arrow_schema = parquet_schema.to_arrow_schema() @@ -328,7 +324,6 @@ def _create_children_spec( Returns: a child or leaf _PrensorTypeSpec. """ - # pylint: disable=protected-access curr_steps_as_set = collections.OrderedDict() # Construct the dictionary of paths we need. @@ -370,7 +365,6 @@ def _create_prensor_spec(self) -> prensor._PrensorTypeSpec: # pylint: disable=p Returns: a root _PrensorTypeSpec. """ - metadata = pq.ParquetFile(self._filenames[0]).metadata parquet_schema = metadata.schema arrow_schema = parquet_schema.to_arrow_schema() diff --git a/struct2tensor/expression_impl/parquet_test.py b/struct2tensor/expression_impl/parquet_test.py index 7712f6d..efb0a2a 100644 --- a/struct2tensor/expression_impl/parquet_test.py +++ b/struct2tensor/expression_impl/parquet_test.py @@ -13,15 +13,15 @@ # limitations under the License. 
"""Tests for struct2tensor.parquet.""" -from pyarrow.lib import ArrowIOError -from struct2tensor import path -from struct2tensor import prensor -from struct2tensor.expression_impl import parquet -from struct2tensor.expression_impl import project -from struct2tensor.expression_impl import promote import tensorflow.compat.v2 as tf from absl.testing import absltest -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import +from pyarrow.lib import ArrowIOError +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) + +from struct2tensor import path, prensor +from struct2tensor.expression_impl import parquet, project, promote tf.enable_v2_behavior() diff --git a/struct2tensor/expression_impl/parse_message_level_ex.py b/struct2tensor/expression_impl/parse_message_level_ex.py index cba7895..9f072f6 100644 --- a/struct2tensor/expression_impl/parse_message_level_ex.py +++ b/struct2tensor/expression_impl/parse_message_level_ex.py @@ -75,14 +75,14 @@ # pylint: disable=protected-access import collections -from typing import List, Mapping, Optional, Sequence, Set +from typing import Mapping, Optional, Sequence, Set -from struct2tensor import path -from struct2tensor.ops import struct2tensor_ops import tensorflow as tf - from google.protobuf import descriptor +from struct2tensor import path +from struct2tensor.ops import struct2tensor_ops + # To the best of my knowledge, ProtoFieldNames ARE strings. ProtoFieldName = str ProtoFullName = str diff --git a/struct2tensor/expression_impl/parse_message_level_ex_test.py b/struct2tensor/expression_impl/parse_message_level_ex_test.py index cffc37c..c6e1aed 100644 --- a/struct2tensor/expression_impl/parse_message_level_ex_test.py +++ b/struct2tensor/expression_impl/parse_message_level_ex_test.py @@ -18,16 +18,15 @@ work correctly. """ -from absl.testing import absltest -from absl.testing import parameterized -from struct2tensor.expression_impl import parse_message_level_ex -from struct2tensor.test import test_any_pb2 -from struct2tensor.test import test_map_pb2 -from struct2tensor.test import test_pb2 import tensorflow as tf - +from absl.testing import absltest, parameterized from google.protobuf import text_format -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) + +from struct2tensor.expression_impl import parse_message_level_ex +from struct2tensor.test import test_any_pb2, test_map_pb2, test_pb2 _INDEX = "index" _VALUE = "value" diff --git a/struct2tensor/expression_impl/placeholder.py b/struct2tensor/expression_impl/placeholder.py index 782e9a7..f7a79c8 100644 --- a/struct2tensor/expression_impl/placeholder.py +++ b/struct2tensor/expression_impl/placeholder.py @@ -35,11 +35,7 @@ import typing from typing import FrozenSet, List, Optional, Sequence, Union -from struct2tensor import calculate -from struct2tensor import calculate_options -from struct2tensor import expression -from struct2tensor import path -from struct2tensor import prensor +from struct2tensor import calculate, calculate_options, expression, path, prensor from struct2tensor.expression_impl import map_prensor_to_prensor as mpp @@ -55,7 +51,6 @@ def create_expression_from_schema( A PlaceholderRootExpression that should be used as the root of an expression graph. 
""" - return _PlaceholderRootExpression(schema) @@ -201,7 +196,7 @@ def _get_child_impl(self, return _PlaceholderChildExpression(self, field_name, child_schema) def __str__(self) -> str: - return "_PlaceholderRootExpression: {}".format(str(self._schema)) + return f"_PlaceholderRootExpression: {str(self._schema)}" def known_field_names(self) -> FrozenSet[path.Step]: return self._schema.known_field_names() diff --git a/struct2tensor/expression_impl/placeholder_test.py b/struct2tensor/expression_impl/placeholder_test.py index 44fc0a4..957a4a4 100644 --- a/struct2tensor/expression_impl/placeholder_test.py +++ b/struct2tensor/expression_impl/placeholder_test.py @@ -13,18 +13,16 @@ # limitations under the License. """Tests for struct2tensor.expression_impl.placeholder.""" +import tensorflow as tf from absl.testing import absltest -from struct2tensor import calculate -from struct2tensor import path -from struct2tensor import prensor +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) + +from struct2tensor import calculate, path, prensor from struct2tensor.expression_impl import map_prensor_to_prensor as mpp -from struct2tensor.expression_impl import placeholder -from struct2tensor.expression_impl import project -from struct2tensor.expression_impl import promote +from struct2tensor.expression_impl import placeholder, project, promote from struct2tensor.test import prensor_test_util -import tensorflow as tf - -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import @test_util.run_all_in_graph_and_eager_modes diff --git a/struct2tensor/expression_impl/project.py b/struct2tensor/expression_impl/project.py index 1983698..4c69b49 100644 --- a/struct2tensor/expression_impl/project.py +++ b/struct2tensor/expression_impl/project.py @@ -16,7 +16,6 @@ project is often used right before calculating the value. Example: - ``` expr = ... new_expr = project.project(expr, [path.Path(["foo","bar"]), @@ -31,15 +30,12 @@ import collections from typing import FrozenSet, List, Mapping, Optional, Sequence -from struct2tensor import calculate_options -from struct2tensor import expression -from struct2tensor import path -from struct2tensor import prensor +from struct2tensor import calculate_options, expression, path, prensor def project(expr: expression.Expression, paths: Sequence[path.Path]) -> expression.Expression: - """select a subtree. + """Select a subtree. Paths not selected are removed. 
Paths that are selected are "known", such that if calculate_prensors is @@ -113,7 +109,7 @@ def _get_child_impl(self, original = self._origin.get_child(field_name) if original is None: raise ValueError( - "Project a field that doesn't exist: {}".format(field_name)) + f"Project a field that doesn't exist: {field_name}") return _ProjectExpression(original, paths) def known_field_names(self) -> FrozenSet[path.Step]: diff --git a/struct2tensor/expression_impl/project_test.py b/struct2tensor/expression_impl/project_test.py index 75c759d..e8ac610 100644 --- a/struct2tensor/expression_impl/project_test.py +++ b/struct2tensor/expression_impl/project_test.py @@ -14,8 +14,8 @@ """Tests for struct2tensor.project.""" from absl.testing import absltest -from struct2tensor import create_expression -from struct2tensor import path + +from struct2tensor import create_expression, path from struct2tensor.expression_impl import project from struct2tensor.test import prensor_test_util diff --git a/struct2tensor/expression_impl/promote.py b/struct2tensor/expression_impl/promote.py index 3f14a19..b2a1f92 100644 --- a/struct2tensor/expression_impl/promote.py +++ b/struct2tensor/expression_impl/promote.py @@ -98,15 +98,11 @@ from typing import FrozenSet, Optional, Sequence, Tuple -from struct2tensor import calculate_options -from struct2tensor import expression -from struct2tensor import expression_add -from struct2tensor import path -from struct2tensor import prensor import tensorflow as tf - from tensorflow_metadata.proto.v0 import schema_pb2 +from struct2tensor import calculate_options, expression, expression_add, path, prensor + class PromoteExpression(expression.Leaf): """A promoted leaf.""" @@ -326,7 +322,7 @@ def _promote_impl(root: expression.Expression, p: path.Path, An _AddPathsExpression that wraps a PromoteExpression. """ if len(p) < 2: - raise ValueError("Cannot do a promotion beyond the root: {}".format(str(p))) + raise ValueError(f"Cannot do a promotion beyond the root: {str(p)}") parent_path = p.get_parent() grandparent_path = parent_path.get_parent() diff --git a/struct2tensor/expression_impl/promote_and_broadcast.py b/struct2tensor/expression_impl/promote_and_broadcast.py index c916ccd..238aff3 100644 --- a/struct2tensor/expression_impl/promote_and_broadcast.py +++ b/struct2tensor/expression_impl/promote_and_broadcast.py @@ -112,11 +112,8 @@ from typing import Mapping, Tuple -from struct2tensor import expression -from struct2tensor import expression_add -from struct2tensor import path -from struct2tensor.expression_impl import broadcast -from struct2tensor.expression_impl import promote +from struct2tensor import expression, expression_add, path +from struct2tensor.expression_impl import broadcast, promote def promote_and_broadcast_anonymous( @@ -161,7 +158,6 @@ def promote_and_broadcast(root: expression.Expression, A new expression, where all the origin paths are promoted and broadcast until they are children of dest_path_parent. """ - result_paths = {} # Here, we branch out and create a different tree for each field that is # promoted and broadcast. diff --git a/struct2tensor/expression_impl/promote_and_broadcast_test.py b/struct2tensor/expression_impl/promote_and_broadcast_test.py index 9b489c1..eafa6ce 100644 --- a/struct2tensor/expression_impl/promote_and_broadcast_test.py +++ b/struct2tensor/expression_impl/promote_and_broadcast_test.py @@ -13,13 +13,12 @@ # limitations under the License. 
"""Tests for struct2tensor.broadcast.""" +import tensorflow as tf from absl.testing import absltest -from struct2tensor import create_expression -from struct2tensor import path + +from struct2tensor import create_expression, path from struct2tensor.expression_impl import promote_and_broadcast -from struct2tensor.test import expression_test_util -from struct2tensor.test import prensor_test_util -import tensorflow as tf +from struct2tensor.test import expression_test_util, prensor_test_util class PromoteAndBroadcastTest(absltest.TestCase): diff --git a/struct2tensor/expression_impl/promote_test.py b/struct2tensor/expression_impl/promote_test.py index 691fd97..71490ae 100644 --- a/struct2tensor/expression_impl/promote_test.py +++ b/struct2tensor/expression_impl/promote_test.py @@ -13,17 +13,17 @@ # limitations under the License. """Tests for struct2tensor.promote.""" -from struct2tensor import create_expression -from struct2tensor import path -from struct2tensor.expression_impl import promote -from struct2tensor.test import expression_test_util -from struct2tensor.test import prensor_test_util import tensorflow as tf - from absl.testing import absltest -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) from tensorflow_metadata.proto.v0 import schema_pb2 +from struct2tensor import create_expression, path +from struct2tensor.expression_impl import promote +from struct2tensor.test import expression_test_util, prensor_test_util + class PromoteTest(absltest.TestCase): diff --git a/struct2tensor/expression_impl/proto.py b/struct2tensor/expression_impl/proto.py index ce61c9d..95a03ae 100644 --- a/struct2tensor/expression_impl/proto.py +++ b/struct2tensor/expression_impl/proto.py @@ -19,21 +19,26 @@ """ import abc -from typing import Callable, FrozenSet, Mapping, Optional, Sequence, Set, Tuple, Union, cast +from typing import ( + Callable, + FrozenSet, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Union, + cast, +) -from struct2tensor import calculate_options -from struct2tensor import expression -from struct2tensor import expression_add -from struct2tensor import path -from struct2tensor import prensor -from struct2tensor.expression_impl import parse_message_level_ex -from struct2tensor.ops import struct2tensor_ops import tensorflow as tf - -from google.protobuf import descriptor_pb2 -from google.protobuf import descriptor +from google.protobuf import descriptor, descriptor_pb2 from google.protobuf.descriptor_pool import DescriptorPool +from struct2tensor import calculate_options, expression, expression_add, path, prensor +from struct2tensor.expression_impl import parse_message_level_ex +from struct2tensor.ops import struct2tensor_ops + # To the best of my knowledge, ProtoFieldNames ARE strings. # Also includes extensions, encoded in a parentheses like (foo.bar.Baz). ProtoFieldName = str @@ -70,7 +75,6 @@ def create_expression_from_file_descriptor_set( Returns: An expression. """ - pool = DescriptorPool() for f in file_descriptor_set.file: # This method raises if f's dependencies have not been added. 
@@ -173,8 +177,7 @@ def transform_fn(parent_indices, values): source_expr = expr.get_descendant_or_error(source_path) if not isinstance(source_expr, _ProtoChildExpression): raise ValueError( - "Expected _ProtoChildExpression for field {}, but found {}.".format( - str(source_path), source_expr)) + f"Expected _ProtoChildExpression for field {str(source_path)}, but found {source_expr}.") if isinstance(source_expr, _TransformProtoChildExpression): # In order to be able to propagate fields needed for parsing, the source @@ -290,8 +293,7 @@ def calculate( parent_value, _ProtoChildNodeTensor): parsed_field = parent_value.fields.get(self.name_as_field) if parsed_field is None: - raise ValueError("Cannot find {} in {}".format( - str(self), str(parent_value))) + raise ValueError(f"Cannot find {str(self)} in {str(parent_value)}") return self.calculate_from_parsed_field(parsed_field, destinations, options) raise ValueError("Not a _ParentProtoNodeTensor: " + str(type(parent_value))) @@ -357,8 +359,7 @@ def known_field_names(self) -> FrozenSet[path.Step]: return frozenset() def __str__(self) -> str: - return "_ProtoLeafExpression: {} from {}".format(self.name_as_field, - self._parent) + return f"_ProtoLeafExpression: {self.name_as_field} from {self._parent}" class _ProtoChildExpression(_AbstractProtoChildExpression): @@ -379,6 +380,7 @@ def __init__(self, parent: "_ParentProtoExpression", This does not take a field descriptor so it can represent syntactic sugar fields such as Any and Maps. + Args: parent: the parent. desc: the message descriptor of the submessage represented by this @@ -427,8 +429,7 @@ def known_field_names(self) -> FrozenSet[path.Step]: return _known_field_names_from_descriptor(self._desc) def __str__(self) -> str: - return "_ProtoChildExpression: name_as_field: {} desc: {} from {}".format( - str(self.name_as_field), str(self._desc.full_name), self._parent) + return f"_ProtoChildExpression: name_as_field: {str(self.name_as_field)} desc: {str(self._desc.full_name)} from {self._parent}" class _TransformProtoChildExpression(_ProtoChildExpression): @@ -475,10 +476,8 @@ def calculation_equal(self, expr: expression.Expression) -> bool: and self.transform_fn is expr.transform_fn) def __str__(self) -> str: - return ("_TransformProtoChildExpression: name_as_field: {} desc: {} from {}" - .format( - str(self.name_as_field), str(self._desc.full_name), - self._parent)) + return (f"_TransformProtoChildExpression: name_as_field: {str(self.name_as_field)} desc: {str(self._desc.full_name)} from {self._parent}" + ) class _ProtoRootExpression(expression.Expression): @@ -560,7 +559,7 @@ def known_field_names(self) -> FrozenSet[path.Step]: return _known_field_names_from_descriptor(self._descriptor) def __str__(self) -> str: - return "_ProtoRootExpression: {}".format(str(self._descriptor.full_name)) + return f"_ProtoRootExpression: {str(self._descriptor.full_name)}" ProtoExpression = Union[_ProtoRootExpression, _ProtoChildExpression, # pylint: disable=invalid-name diff --git a/struct2tensor/expression_impl/proto_test.py b/struct2tensor/expression_impl/proto_test.py index 251549d..88927fa 100644 --- a/struct2tensor/expression_impl/proto_test.py +++ b/struct2tensor/expression_impl/proto_test.py @@ -13,21 +13,21 @@ # limitations under the License. 
"""Tests for struct2tensor.proto.""" -from absl.testing import absltest -from absl.testing import parameterized -from struct2tensor import calculate_options -from struct2tensor import path -from struct2tensor.expression_impl import proto -from struct2tensor.expression_impl import proto_test_util -from struct2tensor.test import expression_test_util -from struct2tensor.test import test_any_pb2 -from struct2tensor.test import test_extension_pb2 -from struct2tensor.test import test_map_pb2 -from struct2tensor.test import test_pb2 -from struct2tensor.test import test_proto3_pb2 import tensorflow as tf - -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import +from absl.testing import absltest, parameterized +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) + +from struct2tensor import calculate_options, path +from struct2tensor.expression_impl import proto, proto_test_util +from struct2tensor.test import ( + expression_test_util, + test_any_pb2, + test_extension_pb2, + test_map_pb2, + test_pb2, +) def _get_expression_with_any(): diff --git a/struct2tensor/expression_impl/proto_test_util.py b/struct2tensor/expression_impl/proto_test_util.py index 83b92e3..46d7795 100644 --- a/struct2tensor/expression_impl/proto_test_util.py +++ b/struct2tensor/expression_impl/proto_test_util.py @@ -14,10 +14,10 @@ """Utilities for tests on proto expressions.""" import tensorflow as tf +from google.protobuf import text_format -from struct2tensor.test import test_pb2 from struct2tensor.expression_impl import proto -from google.protobuf import text_format +from struct2tensor.test import test_pb2 def text_to_tensor(text_list, example_proto_clz): diff --git a/struct2tensor/expression_impl/raw_parquet_dataset_test.py b/struct2tensor/expression_impl/raw_parquet_dataset_test.py index 8eca9d3..e1c826a 100644 --- a/struct2tensor/expression_impl/raw_parquet_dataset_test.py +++ b/struct2tensor/expression_impl/raw_parquet_dataset_test.py @@ -13,10 +13,13 @@ # limitations under the License. """Tests for struct2tensor.parquet.""" -from struct2tensor.expression_impl import parquet import tensorflow as tf from absl.testing import absltest -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) + +from struct2tensor.expression_impl import parquet class ParquetDatasetTestBase(tf.test.TestCase): diff --git a/struct2tensor/expression_impl/reroot.py b/struct2tensor/expression_impl/reroot.py index 80e79db..0d2e894 100644 --- a/struct2tensor/expression_impl/reroot.py +++ b/struct2tensor/expression_impl/reroot.py @@ -20,13 +20,10 @@ """ from typing import FrozenSet, Optional, Sequence -from struct2tensor import calculate_options -from struct2tensor import expression -from struct2tensor import expression_add -from struct2tensor import path -from struct2tensor import prensor import tensorflow as tf +from struct2tensor import calculate_options, expression, expression_add, path, prensor + def reroot(root: expression.Expression, source_path: path.Path) -> expression.Expression: @@ -42,7 +39,6 @@ def reroot(root: expression.Expression, Returns: the new root. 
""" - new_root = root for step in source_path.field_list: new_root = _RerootExpression(new_root, step) @@ -93,8 +89,7 @@ def __init__(self, original_root: expression.Expression, self._original_root = original_root self._new_root = original_root.get_child_or_error(field_name) if self._new_root.type is not None: - raise ValueError("New root must be a message type: {}".format( - str(self._field_name))) + raise ValueError(f"New root must be a message type: {str(self._field_name)}") # TODO(martinz): Check that the "original root source expression" has a type # in (_RerootExpression, prensor._ProtoRootExpression) # To do this, we need a general technique similar to @@ -167,8 +162,7 @@ def calculate( _get_input_proto_index(root_node), is_repeated=False) raise ValueError( - "Illegal operation: expected a true root node: got {}".format( - str(root_node))) + f"Illegal operation: expected a true root node: got {str(root_node)}") def calculation_is_identity(self) -> bool: return False diff --git a/struct2tensor/expression_impl/reroot_test.py b/struct2tensor/expression_impl/reroot_test.py index 4c34371..ad3a8ff 100644 --- a/struct2tensor/expression_impl/reroot_test.py +++ b/struct2tensor/expression_impl/reroot_test.py @@ -13,19 +13,15 @@ # limitations under the License. """Tests for struct2tensor.reroot.""" -from absl.testing import absltest -from struct2tensor import calculate -from struct2tensor import create_expression -from struct2tensor import path -from struct2tensor.expression_impl import proto_test_util -from struct2tensor.expression_impl import reroot -from struct2tensor.test import expression_test_util -from struct2tensor.test import prensor_test_util -from struct2tensor.test import test_pb2 import tensorflow as tf +from absl.testing import absltest +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) - -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import +from struct2tensor import calculate, create_expression, path +from struct2tensor.expression_impl import proto_test_util, reroot +from struct2tensor.test import expression_test_util, prensor_test_util, test_pb2 @test_util.run_all_in_graph_and_eager_modes diff --git a/struct2tensor/expression_impl/size.py b/struct2tensor/expression_impl/size.py index aa430be..528483d 100644 --- a/struct2tensor/expression_impl/size.py +++ b/struct2tensor/expression_impl/size.py @@ -31,14 +31,11 @@ """ from typing import Optional, Sequence, Tuple -from struct2tensor import calculate_options -from struct2tensor import expression -from struct2tensor import expression_add -from struct2tensor import path -from struct2tensor import prensor -from struct2tensor.expression_impl import map_values import tensorflow as tf +from struct2tensor import calculate_options, expression, expression_add, path, prensor +from struct2tensor.expression_impl import map_values + def size_anonymous(root: expression.Expression, source_path: path.Path ) -> Tuple[expression.Expression, path.Path]: @@ -152,7 +149,7 @@ def _size_impl( if not source_path: raise ValueError("Cannot get the size of the root.") if root.get_descendant(source_path) is None: - raise ValueError("Path not found: {}".format(str(source_path))) + raise ValueError(f"Path not found: {str(source_path)}") parent_path = source_path.get_parent() new_path = parent_path.get_child(new_field_name) return expression_add.add_paths( diff --git a/struct2tensor/expression_impl/size_test.py b/struct2tensor/expression_impl/size_test.py index 
d3140ca..a886b37 100644 --- a/struct2tensor/expression_impl/size_test.py +++ b/struct2tensor/expression_impl/size_test.py @@ -13,15 +13,15 @@ # limitations under the License. """Tests for struct2tensor.broadcast.""" -from absl.testing import absltest -from struct2tensor import create_expression -from struct2tensor import path -from struct2tensor.expression_impl import size -from struct2tensor.test import expression_test_util -from struct2tensor.test import prensor_test_util import tensorflow as tf +from absl.testing import absltest +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import +from struct2tensor import create_expression, path +from struct2tensor.expression_impl import size +from struct2tensor.test import expression_test_util, prensor_test_util @test_util.run_all_in_graph_and_eager_modes diff --git a/struct2tensor/expression_impl/slice_expression.py b/struct2tensor/expression_impl/slice_expression.py index 92ed808..8e42701 100644 --- a/struct2tensor/expression_impl/slice_expression.py +++ b/struct2tensor/expression_impl/slice_expression.py @@ -13,7 +13,6 @@ # limitations under the License. """Implementation of slice. - The slice operation is meant to replicate the slicing of a list in python. Slicing a list in python is done by specifying a beginning and ending. @@ -115,14 +114,11 @@ from typing import Callable, Optional, Tuple -from struct2tensor import expression -from struct2tensor import expression_add -from struct2tensor import path -from struct2tensor.expression_impl import filter_expression -from struct2tensor.expression_impl import index -from struct2tensor.expression_impl import map_values import tensorflow as tf +from struct2tensor import expression, expression_add, path +from struct2tensor.expression_impl import filter_expression, index, map_values + IndexValue = expression.IndexValue @@ -177,7 +173,7 @@ def _get_mask(t: expression.Expression, p: path.Path, threshold: IndexValue, ValueError: if p is not in t. """ if t.get_descendant(p) is None: - raise ValueError("Path not found: {}".format(str(p))) + raise ValueError(f"Path not found: {str(p)}") work_expr, index_from_end = index.get_index_from_end( t, p, path.get_anonymous_field()) work_expr, mask_for_negative_threshold = map_values.map_values_anonymous( diff --git a/struct2tensor/expression_impl/slice_expression_test.py b/struct2tensor/expression_impl/slice_expression_test.py index 570a67c..9f87b1e 100644 --- a/struct2tensor/expression_impl/slice_expression_test.py +++ b/struct2tensor/expression_impl/slice_expression_test.py @@ -13,17 +13,20 @@ # limitations under the License. 
"""Tests for slice_expression.""" +import tensorflow as tf from absl.testing import absltest -from struct2tensor import calculate -from struct2tensor import create_expression -from struct2tensor import path +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) + # For tf.Session.Run against a Prensor -from struct2tensor import prensor_value # pylint: disable=unused-import +from struct2tensor import ( + calculate, + create_expression, + path, # pylint: disable=unused-import +) from struct2tensor.expression_impl import slice_expression from struct2tensor.test import prensor_test_util -import tensorflow as tf - -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import @test_util.run_all_in_graph_and_eager_modes diff --git a/struct2tensor/expression_test.py b/struct2tensor/expression_test.py index d4ba359..0662133 100644 --- a/struct2tensor/expression_test.py +++ b/struct2tensor/expression_test.py @@ -20,18 +20,16 @@ For further tests on expressions, see v1_compat_test. """ -from absl.testing import absltest - -from struct2tensor import create_expression -from struct2tensor import path -from struct2tensor.test import expression_test_util -from struct2tensor.test import prensor_test_util - import tensorflow as tf - -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import +from absl.testing import absltest +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) from tensorflow_metadata.proto.v0 import schema_pb2 +from struct2tensor import create_expression, path +from struct2tensor.test import expression_test_util, prensor_test_util + def _features_as_map(feature_list): return {feature.name: feature for feature in feature_list} diff --git a/struct2tensor/ops/__pycache__/__init__.cpython-311.pyc b/struct2tensor/ops/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3217ae05cd54cb60baaf46700914d3c6f1d37328 GIT binary patch literal 185 zcmZ3^%ge<81iN26OxwuB!0;HvfdNJ+<1-tOF`XfWA(%mv(QhR~5l|t+XOQGCAN`E{ z+*JL7l-$yyqQpvlm(=3ylKcYw^8BLg;)2BFRQ=+TqSE9NqmtCT;`|~sUVcHbetdjp tUS>&ryk0@&FAkgB{FKt1RJ$Tppy43vi}``X2WCb_#t#fIqKFwN1^`aeF=_w+ literal 0 HcmV?d00001 diff --git a/struct2tensor/ops/__pycache__/file_descriptor_set.cpython-311.pyc b/struct2tensor/ops/__pycache__/file_descriptor_set.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43cbfba9d5ace85d4695d9809630941ef5a61762 GIT binary patch literal 6135 zcma)A|8LvY5q}i*g`zCUj_pKtotUxXFR~LkPMWN6-K1$8r{9*QO}YUF28|JETP`hX zk5p0{GAa}W<_2!l0RxiXW(-@KqT7%k_6O`wkOBe`2pBL{?1%lvAS+N57_hr1iy~#G z9YG!6Bk%6rdw1`1_w-dfE)uv_oc}I<+e65=*lAvoy~eZ8IYRCeg)9(-Qvw>d8dwOd z@(VnNeO?Q$h89As9n`|Bk%b6rhoCJi2+$6z!b-HMREelO^x3HRuCMulD1sVW3@g!3 zf(spr$hcf7A6LMuuk~dysKh`vu5|o}ERKDER*;Q@Y(hz>&(RLnMOKncV}U;j{GCdt zcP>g7NO!d?0-5KH7gi2|{Gp$e_l_A>xA6ADz49NxDNCxEUCe80R#A-{%@<9b zW({~(q`Ik(rzQAqNU}5!EORc$4b_wuslEyyT4EVdoX$vZm<~&4DKhOF*pLjpA*{+& z%d51hq&Z_lQuM-@3C5sWQIeQqNdr@CE8#~$NTDzk$#fH^(n5t%p7IJ}RP{Nbn`#l?n#(h#IEAGcR z@~h}Zpd2W3w>iqT?hUdfDgg*4e*M+=m@Y%QNlU7fUtIJrkdFQ4dP_0|mV%EXGk-J9 zQQTN*o(k|_ZJZ7I1Pop?&@}MP55JsR)>qZ3qOw||R9>5!Q;iQyy*PD8rym$aIj2q; zCN1U6Gp1TFblTR_i^f#jvS7q#ifeXX77T-(us0$&HoGhr6iroV0>&PMpK%6Ag*ai- zd&25D?T`T9%}8aX9_e~8P>qbhU5_8ye1G$OhbO}5ZZy8x^9-bj-Pty-k%JxRUI>;r z#bxW3*9 zmV@+mnWrX<6jFj4JoJ~#!EKMumd3Wx!=?*?Fj}}A@av|h%G|cco|c|+*l!s}%7N{0 z>wBxUpA%AjuMhlLP_32UX3eH%>~8{e~a_*a^5=kQizKAZmHT1O%0& 
zpcavJvA*-saSDu}vKFX$y0E0uYL%a)R>`LQKSi!j6Szcx#CWK<`cYKI{@TSWt->+H z8|`SaJ$G9(NV5@_h)Vp=fjqmpEkD}({7*xpM=+E6V$qA-)Ftoi{bS$V2_<}94WCDI z3_eY383C~zn4sT6prgT}Wgx?cNFAC^Yuq1EID8v zFOr8{C{Jq>M~OI*HqjEwHNEY}BRJs9(h<^w5dgea0DDLes-qLb!`7fmPwpy2U7pX7 z=B*wVUrwISo$5vEY}snWuu{^kmUOSAMyyrZ&#CR_l;ZPh@p-5bt7@P<#5rE&sjAwv z>we4T#g=^wE&JpHXK{>FYX9++ah?>^h#UJzjR;7i_IkzNFmrWwi_)|&Sv&WcKl3-u zT$_Cgx&hM+M>zRh^0>>898O+H_S%m4GKOmR;XaqTFqglb4Zy)-xU+|QleK*u5-N>Q zi~101#9y_O%yV*p?zwD|nGRWoWF}FNfAVTX{iqiR2sRsDsGv2d5!XZ@U0Hf}@)&?% zlA-olZq6keGh_b)ebP8J7^XpFluDvUyrQl=d{^HZExFJsi6mBurqiXD3A_T}raWwR zS038Q&F5Q9Xv*4X$1gsO_zgRaFy?i$X@q@Vx6P~hb(;jFE%O;}Kg&p4^<9061fQ?a6j|le2KD0oK&zp9H(G zrY=u4yIPHx`;uGOL@{+AEsU`lhr<( zI&kFf`uQ!yXrE+^_DKh$?Z#ohec`=2wQZk5%bE7eEeFgEYnEEbUr2x#!g6MBUW#e)&n!TEcHe$4=V-@ zLkqdPCe^RrH67E0MuBYIbD3$v?AW)dhtj$u^w81#=xLo&gPH(sAh=W7@dffA7ao2E1eT$k^}=M{nE}+EbQ}Pc^xN) zaxUz{ED`Ivf$1Ro6{bVLXpX}Uj4*8{=4Ee&&KW;|(-qyuHFxZSSg?lJew3Utk{eM6zij`M5;4L4?O08IvMPSd7P3o1uzH!mj6NBHmUipMJLmdyQHrm zrSw%w?$h0`lgBK9Z=t%`z8}PAzet?^m2~NsHWK2{5OWx}$OSC9FiK%8OEy;}cbY4e zO|PzOuWrql17mMw&VjKfTTqv!$wLs93o@b-r=@<~j-c*1Em2QXBmC&g1ob)Ok%{5N z5G-yR6fM5o9CB9VKO}N=7WD;kRNaD&=c(cf+?aJZ=^?vrI;;s|(xC=NzXQm81=cdx z;H>vZt#ZrH8g#DCy$ae%a%k16() z^#_v&?&P~LA$iHmbC>0fyA*!6%I}u>-FDvNa`ICy3ki|L)jvP^;#_m`vchju`E4@4 z&BBC`q^*&g|z@k%sRX-M?CN@ zj%7+TY00kJ8JqPJrA7$IF1NU^w6x60@^PqMTl(BoOoNImTzBmWX(5f1;?X_=((|YoUJm_XC6rqkw`19 z+tK%s%OWaU5vmpHHU=a8*V)!SX*rz79Y6yGC0{%({V3~-PAKeSDQ^Q;C6&ByTCa>#|Yqu+fJJiA*%!6zez5U!b zKP4A-D1{wrVaIp#)g5PlR8;e3hun4mc6U1u{^;Q8R8i@T*uxN?SAo?ac@+`0>4Ax_Xspm)|&}zVlMLA4A4v;YzdAT^KU6&+ZL#A z5iB)!1E%z?6kR8=r$PywQUj+zBzf97LU4&M@wF~-01y3{zMCu*Y=LmsY0QTzfir4= zG@kr9hThxI%rC=`5KY1?EPl>2=&hi!xAmc|rYxh(I@ zRe~a4`w|D>ECKnhw_buxPwYI4xeFz5P7R!+m!SC}f=hguPyi4Lr*E}|f@v>z=Bi3N z3af(2x7r9|Su9oOQD!EnE%R+z;!MyU^r0hU1HCa&6QB*hBT*OtiL|w9!AL|L8^C^F zp?_isBl~Dv%Re%HJz_r9M8>a^f~`ZRP&g%^LyEmD4JBe)$n?^JhF=(onLaKmNuNN= z17+)v53^$iuw`DN5kcSg>xYVC{JM0B;t=mOF_7reElPEajmO!JFpZ=#n?K=Y_QiE_ zN8^D+`|7l?o)FsLm-V%p>nx8Blm~VMJ1*1yVrkdDPQ;UWriEQ;kwJS)}_O8ZSh^)>;jXGRh< zo~k;dJN{I7O5bOFNFtpAZY~}`)xh4G$Fo$&d3MR%&inR0YV}rf$y~8{-{+cT=bhp< z$<8~)mCDAQ;U%&O60zk)3ynYn7dMitCV_cZ#c$jT`f#o~{h6_l2Sve~JJ? zD;rXw2ASQdQ1y+*8;z;zZ8Ez}$5LLP#tRAzP{>s_r9u@q>TcAfLZvrqZq%fzTW;mw z%1{3wDo;@X5VbxP$}^+YrmAb5!Hp)SJcMNhuV<-q1v$%*ox0h`*eGW6t_Gl=cn-mO;Z7Ye*o#RibM3T%v7haH5q$LZ#(@Tth{;R&#*sOtLAGF Iz$n%K1KI}@od5s; literal 0 HcmV?d00001 diff --git a/struct2tensor/ops/file_descriptor_set.py b/struct2tensor/ops/file_descriptor_set.py index 75dcbd7..73dbbaf 100644 --- a/struct2tensor/ops/file_descriptor_set.py +++ b/struct2tensor/ops/file_descriptor_set.py @@ -22,10 +22,9 @@ """ from typing import Sequence, Set -from struct2tensor import path +from google.protobuf import descriptor, descriptor_pb2 -from google.protobuf import descriptor_pb2 -from google.protobuf import descriptor +from struct2tensor import path def _are_dependencies_handled(file_descriptor: descriptor.FileDescriptor, @@ -42,10 +41,13 @@ def _order_dependencies(file_descriptors: Set[descriptor.FileDescriptor] """Given a set of file descriptors, return them as an ordered list. Each file descriptor in the resulting list follows its dependencies. + Args: file_descriptors: a list of file descriptors. + Returns: file descriptors as an ordered list. + Raises: ValueError: if there is a circular dependency. 
""" diff --git a/struct2tensor/ops/file_descriptor_set_test.py b/struct2tensor/ops/file_descriptor_set_test.py index afcdd41..6403d3a 100644 --- a/struct2tensor/ops/file_descriptor_set_test.py +++ b/struct2tensor/ops/file_descriptor_set_test.py @@ -15,15 +15,15 @@ """Tests for struct2tensor.ops.file_descriptor_set.""" from absl.testing import absltest + from struct2tensor.ops import file_descriptor_set -from struct2tensor.test import dependent_test_pb2 # Notice that while text_extension_pb2 is not used directly, # it must be linked in and imported so that the extension can be found in the # pool. -from struct2tensor.test import test_extension_pb2 # pylint: disable=unused-import -from struct2tensor.test import test_map_pb2 -from struct2tensor.test import test_pb2 +from struct2tensor.test import ( + test_map_pb2, +) def _get_base_directory(): diff --git a/struct2tensor/ops/struct2tensor_ops.py b/struct2tensor/ops/struct2tensor_ops.py index 7ba0ac3..cff8385 100644 --- a/struct2tensor/ops/struct2tensor_ops.py +++ b/struct2tensor/ops/struct2tensor_ops.py @@ -16,17 +16,19 @@ from typing import NamedTuple, Optional, Sequence, Tuple -from struct2tensor import path -from struct2tensor.ops import file_descriptor_set -from struct2tensor.ops import gen_decode_proto_map_op -from struct2tensor.ops import gen_decode_proto_sparse -from struct2tensor.ops import gen_equi_join_any_indices -from struct2tensor.ops import gen_equi_join_indices -from struct2tensor.ops import gen_run_length_before import tensorflow as tf - from google.protobuf import descriptor +from struct2tensor import path +from struct2tensor.ops import ( + file_descriptor_set, + gen_decode_proto_map_op, + gen_decode_proto_sparse, + gen_equi_join_any_indices, + gen_equi_join_indices, + gen_run_length_before, +) + def _get_dtype_from_cpp_type(cpp_type: int) -> tf.DType: """Converts a cpp type in FieldDescriptor to the appropriate dtype.""" @@ -52,10 +54,11 @@ def _get_dtype_from_cpp_type(cpp_type: int) -> tf.DType: # field_descriptor: descriptor.FieldDescriptor (note: not used in V2). # value: tf.Tensor # index: tf.Tensor -_ParsedField = NamedTuple( - "_ParsedField", [("field_name", str), - ("field_descriptor", Optional[descriptor.FieldDescriptor]), - ("value", tf.Tensor), ("index", tf.Tensor)]) +class _ParsedField(NamedTuple): + field_name: str + field_descriptor: Optional[descriptor.FieldDescriptor] + value: tf.Tensor + index: tf.Tensor def parse_full_message_level( @@ -69,6 +72,7 @@ def parse_full_message_level( If there is a field with a message type, it is parsed as a string. Then, the function can be applied recursively. Note: this will not extract extensions. + Args: tensor_of_protos: a 1-D tensor of strings of protocol buffers. descriptor_type: a descriptor for the protocol buffer to parse. See @@ -83,6 +87,7 @@ def parse_full_message_level( "optional" or "repeated" label) is requested to be parsed, it will always have a value for each input parent message. If a value is not present on wire, the default value (0 or "") will be used. + Returns: list of named tuples, one per field_name in field_names: field_name: the string from field_names. @@ -92,7 +97,6 @@ def parse_full_message_level( value value[i]. Note that sometimes index[i]=index[i+1], implying a repeated field field_name. 
""" - field_names = [field.name for field in descriptor_type.fields] return parse_message_level( tensor_of_protos, @@ -138,6 +142,7 @@ def parse_message_level( "optional" or "repeated" label) is requested to be parsed, it will always have a value for each input parent message. If a value is not present on wire, the default value (0 or "") will be used. + Returns: list of named _ParsedField, one per field_name in field_names: field_name: the string from field_names. @@ -200,7 +205,6 @@ def parse_message_level( def run_length_before(a: tf.Tensor) -> tf.Tensor: r"""Returns the run length of each set of elements in a vector. - Args: a: a 1D int64 tensor. This assumes that for all a_i, a_j, if i <= j, then a_i <= a_j. diff --git a/struct2tensor/ops/struct2tensor_ops_test.py b/struct2tensor/ops/struct2tensor_ops_test.py index 351a93b..71f2d21 100644 --- a/struct2tensor/ops/struct2tensor_ops_test.py +++ b/struct2tensor/ops/struct2tensor_ops_test.py @@ -15,18 +15,19 @@ import itertools -from absl.testing import absltest -from absl.testing import parameterized import numpy as np -from struct2tensor.ops import struct2tensor_ops -from struct2tensor.test import test_extension_pb2 -from struct2tensor.test import test_map_pb2 -from struct2tensor.test import test_pb2 -from struct2tensor.test import test_proto3_pb2 import tensorflow as tf +from absl.testing import absltest, parameterized +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) - -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import +from struct2tensor.ops import struct2tensor_ops +from struct2tensor.test import ( + test_extension_pb2, + test_map_pb2, + test_pb2, +) INDEX = "index" VALUE = "value" @@ -143,8 +144,7 @@ def test_out_of_order_fields(self): for f in parsed_fields: self.assertAllEqual( expected_field_value[f.field_name], f.value, - "field: {}, permutation: {}, field_to_parse: {}".format( - f.field_name, indices, comb)) + f"field: {f.field_name}, permutation: {indices}, field_to_parse: {comb}") def test_out_of_order_repeated_fields_1(self): # This is a 2-1-2 wire number pattern. 
@@ -557,9 +557,9 @@ def _parse_map_entry(self, messages_with_map, map_field_name, keys_needed): @parameterized.named_parameters( [dict(testcase_name=t, key_type=t) for t in _SIGNED_INTEGER_TYPES]) def test_signed_integer_key_types(self, key_type): - field_name = "{}_string_map".format(key_type) + field_name = f"{key_type}_string_map" message_with_map = test_map_pb2.MessageWithMap() - map_entry = getattr(message_with_map, "{}_string_map".format(key_type)) + map_entry = getattr(message_with_map, f"{key_type}_string_map") map_entry[42] = "hello" map_entry[-42] = "world" @@ -578,9 +578,9 @@ def test_signed_integer_key_types(self, key_type): @parameterized.named_parameters( [dict(testcase_name=t, key_type=t) for t in _UNSIGNED_INTEGER_TYPES]) def test_unsigned_integer_key_types(self, key_type): - field_name = "{}_string_map".format(key_type) + field_name = f"{key_type}_string_map" message_with_map = test_map_pb2.MessageWithMap() - map_entry = getattr(message_with_map, "{}_string_map".format(key_type)) + map_entry = getattr(message_with_map, f"{key_type}_string_map") map_entry[42] = "hello" [(values_42, indices_42), @@ -592,14 +592,14 @@ def test_unsigned_integer_key_types(self, key_type): self.assertAllEqual(indices_0, []) def test_invalid_uint32_key(self): - with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, "Failed to parse .*string"): self.evaluate( self._parse_map_entry([test_map_pb2.MessageWithMap()], "uint32_string_map", ["-42"])) def test_invalid_int32_key(self): - with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, "Failed to parse .*string"): self.evaluate( self._parse_map_entry([test_map_pb2.MessageWithMap()], @@ -617,7 +617,7 @@ def test_bool_key_type(self): def test_invalid_bool_key(self): message_with_map = test_map_pb2.MessageWithMap() - with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, "Failed to parse .*string"): self.evaluate( self._parse_map_entry([message_with_map], "bool_string_map", ["2"])) @@ -625,9 +625,9 @@ def test_invalid_bool_key(self): @parameterized.named_parameters( [dict(testcase_name=t, value_type=t) for t in _SIGNED_INTEGER_TYPES]) def test_signed_integer_value_types(self, value_type): - field_name = "string_{}_map".format(value_type) + field_name = f"string_{value_type}_map" message_with_map = test_map_pb2.MessageWithMap() - map_entry = getattr(message_with_map, "string_{}_map".format(value_type)) + map_entry = getattr(message_with_map, f"string_{value_type}_map") map_entry["foo"] = 42 map_entry["bar"] = -42 [(values_foo, indices_foo), (values_bar, indices_bar), @@ -644,9 +644,9 @@ def test_signed_integer_value_types(self, value_type): @parameterized.named_parameters( [dict(testcase_name=t, value_type=t) for t in _UNSIGNED_INTEGER_TYPES]) def test_unsigned_integer_value_types(self, value_type): - field_name = "string_{}_map".format(value_type) + field_name = f"string_{value_type}_map" message_with_map = test_map_pb2.MessageWithMap() - map_entry = getattr(message_with_map, "string_{}_map".format(value_type)) + map_entry = getattr(message_with_map, f"string_{value_type}_map") map_entry["foo"] = 42 [(values_foo, indices_foo), (values_null, indices_null) ] = self._parse_map_entry([message_with_map], field_name, ["foo", ""]) @@ -658,9 +658,9 @@ def test_unsigned_integer_value_types(self, value_type): @parameterized.named_parameters( 
[dict(testcase_name=t, value_type=t) for t in ["float", "double"]]) def test_fp_value_types(self, value_type): - field_name = "string_{}_map".format(value_type) + field_name = f"string_{value_type}_map" message_with_map = test_map_pb2.MessageWithMap() - map_entry = getattr(message_with_map, "string_{}_map".format(value_type)) + map_entry = getattr(message_with_map, f"string_{value_type}_map") map_entry["foo"] = 0.5 [(values_foo, indices_foo), (values_null, indices_null) ] = self._parse_map_entry([message_with_map], field_name, ["foo", ""]) diff --git a/struct2tensor/path.py b/struct2tensor/path.py index 584e185..698e975 100644 --- a/struct2tensor/path.py +++ b/struct2tensor/path.py @@ -20,7 +20,7 @@ """ import re -from typing import Sequence, Tuple, Union, List +from typing import List, Sequence, Tuple, Union from google.protobuf import descriptor from tensorflow_metadata.proto.v0 import path_pb2 as tf_metadata_path_pb2 @@ -76,7 +76,7 @@ def _compare_step(a: Step, b: Step) -> int: return -1 -class Path(object): +class Path: """A representation of a path in the expression. Do not implement __nonzero__, __eq__, __ne__, et cetera as these are @@ -332,6 +332,7 @@ def create_path(path_source: CoercableToPath) -> Path: Returns: A Path. + Raises: ValueError: if this is not a valid path. """ diff --git a/struct2tensor/path_test.py b/struct2tensor/path_test.py index a7eebe5..dcf7423 100644 --- a/struct2tensor/path_test.py +++ b/struct2tensor/path_test.py @@ -16,14 +16,15 @@ # pylint: disable=protected-access import pprint -from absl.testing import absltest -from absl.testing import parameterized - -from struct2tensor.path import create_path -from struct2tensor.path import expand_wildcard_proto_paths -from struct2tensor.path import from_proto -from struct2tensor.path import parse_map_indexing_step -from struct2tensor.path import Path +from absl.testing import absltest, parameterized + +from struct2tensor.path import ( + Path, + create_path, + expand_wildcard_proto_paths, + from_proto, + parse_map_indexing_step, +) from struct2tensor.test import test_pb2 _FORMAT_TEST_CASES = [ diff --git a/struct2tensor/prensor.py b/struct2tensor/prensor.py index d80d96d..e7c4786 100644 --- a/struct2tensor/prensor.py +++ b/struct2tensor/prensor.py @@ -25,18 +25,19 @@ import enum from typing import FrozenSet, Iterator, List, Mapping, Optional, Sequence, Tuple, Union -from struct2tensor import calculate_options -from struct2tensor import path -from struct2tensor.ops import struct2tensor_ops import tensorflow as tf +from tensorflow.python.framework import ( + composite_tensor, # pylint: disable=g-direct-tensorflow-import +) -from tensorflow.python.framework import composite_tensor # pylint: disable=g-direct-tensorflow-import +from struct2tensor import calculate_options, path +from struct2tensor.ops import struct2tensor_ops # TODO(martinz): Consider creating node.py with the LeafNodeTensor, # ChildNodeTensor, and RootNodeTensor, allowing expression.py to depend upon # node.py. 
-class RootNodeTensor(object): +class RootNodeTensor: """The value of the root.""" __slots__ = ["_size"] @@ -72,7 +73,7 @@ def __str__(self): return "RootNodeTensor" -class ChildNodeTensor(object): +class ChildNodeTensor: """The value of an intermediate node.""" __slots__ = ["_parent_index", "_is_repeated", "_index_to_value"] @@ -142,10 +143,10 @@ def get_positional_index(self) -> tf.Tensor: def __str__(self): cardinality = "repeated" if self.is_repeated else "optional" - return "{} ChildNodeTensor".format(cardinality) + return f"{cardinality} ChildNodeTensor" -class LeafNodeTensor(object): +class LeafNodeTensor: """The value of a leaf node.""" __slots__ = ["_parent_index", "_values", "_is_repeated"] @@ -386,8 +387,7 @@ def get_child_or_error(self, field_name: path.Step) -> "Prensor": result = self._children.get(field_name) if result is not None: return result - raise ValueError("Field not found: {} in {}".format( - str(field_name), str(self))) + raise ValueError(f"Field not found: {str(field_name)} in {str(self)}") def get_descendant(self, p: path.Path) -> Optional["Prensor"]: """Finds the descendant at the path.""" @@ -402,7 +402,7 @@ def get_descendant_or_error(self, p: path.Path) -> "Prensor": """Finds the descendant at the path.""" result = self.get_descendant(p) if result is None: - raise ValueError("Missing path: {} in {}".format(str(p), str(self))) + raise ValueError(f"Missing path: {str(p)} in {str(self)}") return result def get_children(self) -> "collections.OrderedDict[path.Step, Prensor]": @@ -504,10 +504,10 @@ def _string_helper(self, field_name: path.Step) -> Sequence[str]: Returns: lines to run __str__, that are bytes in Python 2 and unicode in Python 3. """ - result = ["{} {}".format(str(self.node), str(field_name))] + result = [f"{str(self.node)} {str(field_name)}"] for k, v in self._children.items(): recursive = v._string_helper(k) # pylint: disable=protected-access - result.extend([" {}".format(x) for x in recursive]) + result.extend([f" {x}" for x in recursive]) return result def __str__(self) -> str: @@ -564,7 +564,7 @@ def create_prensor_from_descendant_nodes( subexpressions[first_step] = {} subexpressions[first_step][suffix] = v if root_node is None: - raise ValueError("No root found: {}".format(str(nodes))) + raise ValueError(f"No root found: {str(nodes)}") return create_prensor_from_root_and_children( root_node, collections.OrderedDict((k, create_prensor_from_descendant_nodes(v)) @@ -582,7 +582,7 @@ def create_prensor_from_root_and_children( ######## Below code is for converting prensor to ragged/sparse tensors.######### -class _LeafNodePath(object): +class _LeafNodePath: """A path ending in a leaf. In order to avoid type checks and casting in the heart of different methods @@ -611,7 +611,7 @@ def tail(self) -> LeafNodeTensor: return self._tail -class _ChildNodePath(object): +class _ChildNodePath: """A _ChildNodePath is a path that ends with a child node. It keeps same triple structure as _LeafNodePath. 
@@ -642,15 +642,14 @@ def _as_root_node_tensor(node_tensor: NodeTensor) -> RootNodeTensor: return node_tensor if isinstance(node_tensor, ChildNodeTensor): return RootNodeTensor(node_tensor.size) - raise ValueError("Must be child or root node tensor (found {})".format( - type(node_tensor))) + raise ValueError(f"Must be child or root node tensor (found {type(node_tensor)})") def _get_leaf_node_path(p: path.Path, t: Prensor) -> _LeafNodePath: """Creates a _LeafNodePath to p.""" leaf_node = t.get_descendant_or_error(p).node if not isinstance(leaf_node, LeafNodeTensor): - raise ValueError("Expected Leaf Node at {} in {}".format(str(p), str(t))) + raise ValueError(f"Expected Leaf Node at {str(p)} in {str(t)}") if not p: raise ValueError("Leaf should not be at the root") # If there is a leaf at the root, this will return a ValueError. @@ -784,7 +783,6 @@ def _get_sparse_tensors( Returns: A map from paths to sparse tensors. """ - del options return { p: _get_sparse_tensor_from_leaf_node_path(v) @@ -799,7 +797,7 @@ def from_value_rowids_bridge(values, value_rowids=None, nrows=None, validate=True): - """validate option is only available internally for tf 0.13.1.""" + """Validate option is only available internally for tf 0.13.1.""" return tf.RaggedTensor.from_value_rowids( values, value_rowids=value_rowids, nrows=nrows, validate=validate) diff --git a/struct2tensor/prensor_test.py b/struct2tensor/prensor_test.py index cc72f01..2dd27c7 100644 --- a/struct2tensor/prensor_test.py +++ b/struct2tensor/prensor_test.py @@ -14,13 +14,13 @@ """Tests for struct2tensor.prensor.""" -from struct2tensor import calculate_options -from struct2tensor import path -from struct2tensor import prensor -from struct2tensor.test import prensor_test_util import tensorflow as tf +from tensorflow.python.framework import ( + test_util, # pylint: disable=g-direct-tensorflow-import +) -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import +from struct2tensor import calculate_options, path, prensor +from struct2tensor.test import prensor_test_util _OPTIONS_TO_TEST = [ calculate_options.get_default_options(), diff --git a/struct2tensor/prensor_to_structured_tensor.py b/struct2tensor/prensor_to_structured_tensor.py index 74b1d38..fd742a8 100644 --- a/struct2tensor/prensor_to_structured_tensor.py +++ b/struct2tensor/prensor_to_structured_tensor.py @@ -29,12 +29,15 @@ from typing import Mapping, Union -from struct2tensor import path -from struct2tensor import prensor import tensorflow as tf - -from tensorflow.python.ops.ragged.row_partition import RowPartition # pylint: disable=g-direct-tensorflow-import -from tensorflow.python.ops.structured import structured_tensor # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.ops.ragged.row_partition import ( + RowPartition, # pylint: disable=g-direct-tensorflow-import +) +from tensorflow.python.ops.structured import ( + structured_tensor, # pylint: disable=g-direct-tensorflow-import +) + +from struct2tensor import path, prensor def prensor_to_structured_tensor( diff --git a/struct2tensor/prensor_to_structured_tensor_test.py b/struct2tensor/prensor_to_structured_tensor_test.py index b170d5b..52f0069 100644 --- a/struct2tensor/prensor_to_structured_tensor_test.py +++ b/struct2tensor/prensor_to_structured_tensor_test.py @@ -13,14 +13,12 @@ # limitations under the License. 
"""Tests for StructuredTensor.""" +import tensorflow as tf from google.protobuf import text_format -from struct2tensor import calculate -from struct2tensor import prensor -from struct2tensor import prensor_to_structured_tensor + +from struct2tensor import calculate, prensor, prensor_to_structured_tensor from struct2tensor.expression_impl import proto -from struct2tensor.test import prensor_test_util -from struct2tensor.test import test_pb2 -import tensorflow as tf +from struct2tensor.test import prensor_test_util, test_pb2 # @test_util.run_all_in_graph_and_eager_modes @@ -47,13 +45,13 @@ def test_nested_prensor(self): def test_big_prensor(self): """Test the big prensor. - a prensor expression representing: - {foo:9, foorepeated:[9], doc:[{bar:["a"], keep_me:False}], - user:[{friends:["a"]}]} - {foo:8, foorepeated:[8,7], - doc:[{bar:["b","c"],keep_me:True},{bar:["d"]}], - user:[{friends:["b", "c"]},{friends:["d"]}],} - {foo:7, foorepeated:[6], user:[friends:["e"]]} + a prensor expression representing: + {foo:9, foorepeated:[9], doc:[{bar:["a"], keep_me:False}], + user:[{friends:["a"]}]} + {foo:8, foorepeated:[8,7], + doc:[{bar:["b","c"],keep_me:True},{bar:["d"]}], + user:[{friends:["b", "c"]},{friends:["d"]}],} + {foo:7, foorepeated:[6], user:[friends:["e"]]} """ pren = prensor_test_util.create_big_prensor() st = prensor_to_structured_tensor.prensor_to_structured_tensor(pren) @@ -70,13 +68,13 @@ def test_big_prensor(self): def test_deep_prensor(self): """Test a prensor with three layers: root, event, and doc. - a prensor expression representing: - {foo:9, foorepeated:[9], user:[{friends:["a"]}], - event:{doc:[{bar:["a"], keep_me:False}]}} - {foo:8, foorepeated:[8,7], - event:{doc:[{bar:["b","c"], keep_me:True},{bar:["d"]}]}, - user:[{friends:["b", "c"]}, {friends:["d"]}]} - {foo:7, foorepeated:[6], user:[friends:["e"]], event:{}} + a prensor expression representing: + {foo:9, foorepeated:[9], user:[{friends:["a"]}], + event:{doc:[{bar:["a"], keep_me:False}]}} + {foo:8, foorepeated:[8,7], + event:{doc:[{bar:["b","c"], keep_me:True},{bar:["d"]}]}, + user:[{friends:["b", "c"]}, {friends:["d"]}]} + {foo:7, foorepeated:[6], user:[friends:["e"]], event:{}} """ pren = prensor_test_util.create_deep_prensor() st = prensor_to_structured_tensor.prensor_to_structured_tensor(pren) @@ -95,7 +93,7 @@ def test_deep_prensor(self): def test_non_root_prensor(self): child_prensor = prensor.create_prensor_from_root_and_children( prensor_test_util.create_child_node([0, 0, 1, 3, 7], True), {}) - with self.assertRaisesRegexp(ValueError, "Must be a root prensor"): + with self.assertRaisesRegex(ValueError, "Must be a root prensor"): prensor_to_structured_tensor.prensor_to_structured_tensor(child_prensor) def test_e2e_proto(self): diff --git a/struct2tensor/prensor_util.py b/struct2tensor/prensor_util.py index 234e6d6..6e2abb6 100644 --- a/struct2tensor/prensor_util.py +++ b/struct2tensor/prensor_util.py @@ -20,12 +20,12 @@ from typing import Mapping -from struct2tensor import calculate_options -from struct2tensor import path -from struct2tensor import prensor import tensorflow as tf +from tensorflow.python.util import ( + deprecation, # pylint: disable=g-direct-tensorflow-import +) -from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import +from struct2tensor import calculate_options, path, prensor @deprecation.deprecated(None, "Use the Prensor class method instead.") diff --git a/struct2tensor/prensor_value.py b/struct2tensor/prensor_value.py index c424eeb..eb57ef8 100644 --- 
a/struct2tensor/prensor_value.py +++ b/struct2tensor/prensor_value.py @@ -28,14 +28,15 @@ from typing import FrozenSet, Iterator, Mapping, Optional, Sequence, Union import numpy as np -from struct2tensor import path -from struct2tensor import prensor import tensorflow as tf +from tensorflow.python.client import ( + session as session_lib, # pylint: disable=g-direct-tensorflow-import +) -from tensorflow.python.client import session as session_lib # pylint: disable=g-direct-tensorflow-import +from struct2tensor import path, prensor -class RootNodeValue(object): +class RootNodeValue: """The value of the root.""" __slots__ = ["_size"] @@ -60,13 +61,13 @@ def schema_string(self): return "repeated" def data_string(self): - return "size: {}".format(self._size) + return f"size: {self._size}" def __str__(self): return "RootNode" -class ChildNodeValue(object): +class ChildNodeValue: """The value of an intermediate node.""" __slots__ = ["_parent_index", "_is_repeated"] @@ -104,13 +105,13 @@ def schema_string(self) -> str: return "repeated" if self.is_repeated else "optional" def data_string(self): - return "parent_index: {}".format(self._parent_index) + return f"parent_index: {self._parent_index}" def __str__(self): - return "ChildNode {} {}".format(self.schema_string(), self.data_string()) + return f"ChildNode {self.schema_string()} {self.data_string()}" -class LeafNodeValue(object): +class LeafNodeValue: """The value of a leaf node.""" __slots__ = ["_parent_index", "_values", "_is_repeated"] @@ -143,11 +144,10 @@ def values(self): return self._values def data_string(self): - return "parent_index: {} values: {}".format(self._parent_index, - self._values) + return f"parent_index: {self._parent_index} values: {self._values}" def schema_string(self) -> str: - return u"{} {}".format("repeated" if self.is_repeated else "optional", + return "{} {}".format("repeated" if self.is_repeated else "optional", str(self.values.dtype)) def __str__(self): @@ -158,7 +158,7 @@ def __str__(self): NodeValue = Union[RootNodeValue, ChildNodeValue, LeafNodeValue] # pylint: disable=invalid-name -class PrensorValue(object): +class PrensorValue: """A tree of NodeValue objects.""" __slots__ = ["_node", "_children"] @@ -195,7 +195,7 @@ def get_child_or_error(self, field_name: path.Step) -> "PrensorValue": result = self._children.get(field_name) if result is not None: return result - raise ValueError("Field not found: {}".format(str(field_name))) + raise ValueError(f"Field not found: {str(field_name)}") def get_descendant(self, p: path.Path) -> Optional["PrensorValue"]: """Finds the descendant at the path.""" @@ -210,7 +210,7 @@ def get_descendant_or_error(self, p: path.Path) -> "PrensorValue": """Finds the descendant at the path.""" result = self.get_descendant(p) if result is None: - raise ValueError("Missing path: {}".format(str(p))) + raise ValueError(f"Missing path: {str(p)}") return result def get_children(self) -> Mapping[path.Step, "PrensorValue"]: @@ -233,25 +233,24 @@ def field_names(self) -> FrozenSet[path.Step]: def _string_helper(self, field_name: str) -> Sequence[str]: """Helper for __str__ that outputs a list of lines.""" result = [ - "{} {} {}".format(self.node.schema_string(), str(field_name), - self.node.data_string()) + f"{self.node.schema_string()} {str(field_name)} {self.node.data_string()}" ] for k, v in self._children.items(): recursive = v._string_helper(k) # pylint: disable=protected-access - result.extend([" {}".format(x) for x in recursive]) + result.extend([f" {x}" for x in recursive]) return result 
   def _schema_string_helper(self, field_name: str) -> Sequence[str]:
     """Helper for __str__ that outputs a list of lines."""
-    result = [u"{} {}".format(self.node.schema_string(), str(field_name))]
+    result = [f"{self.node.schema_string()} {str(field_name)}"]
     for k, v in self._children.items():
       recursive = v._string_helper(k)  # pylint: disable=protected-access
-      result.extend([u"  {}".format(x) for x in recursive])
+      result.extend([f"  {x}" for x in recursive])
     return result
 
   def schema_string(self):
     """Returns a string representing the schema of the Prensor."""
-    return u"\n".join(self._schema_string_helper(""))
+    return "\n".join(self._schema_string_helper(""))
 
   def __str__(self):
     """Returns a string representing the schema of the Prensor."""
diff --git a/struct2tensor/prensor_value_test.py b/struct2tensor/prensor_value_test.py
index 76cd84f..2b6aae3 100644
--- a/struct2tensor/prensor_value_test.py
+++ b/struct2tensor/prensor_value_test.py
@@ -13,12 +13,15 @@
 # limitations under the License.
 """Tests for struct2tensor.create_expression."""
 
-from struct2tensor import path
-from struct2tensor import prensor
-from struct2tensor import prensor_value  # pylint: disable=unused-import
-from struct2tensor.test import prensor_test_util
 import tensorflow as tf
 
+from struct2tensor import (
+    path,
+    prensor,
+    prensor_value,  # pylint: disable=unused-import
+)
+from struct2tensor.test import prensor_test_util
+
 
 # This is a V1 test. The prensor value is not meaningful for V2.
 class PrensorValueTest(tf.test.TestCase):
diff --git a/struct2tensor/struct2tensor_module_test.py b/struct2tensor/struct2tensor_module_test.py
index 73919ed..7fda60f 100644
--- a/struct2tensor/struct2tensor_module_test.py
+++ b/struct2tensor/struct2tensor_module_test.py
@@ -14,6 +14,7 @@
 """Tests for struct2tensor.__init__.py."""
 
 from absl.testing import absltest
+
 import struct2tensor as s2t
 
 
@@ -54,7 +55,6 @@ def test_importing_struct2tensor_modules(self):
 
   def test_importing_expression_impl_modules(self):
     """This tests that the expression_impl/__init__.py imports are found."""
-
     from struct2tensor import expression_impl  # pylint: disable=g-import-not-at-top
 
     modules = [
diff --git a/struct2tensor/structured_tensor_to_prensor.py b/struct2tensor/structured_tensor_to_prensor.py
index dd63b66..2f220d8 100644
--- a/struct2tensor/structured_tensor_to_prensor.py
+++ b/struct2tensor/structured_tensor_to_prensor.py
@@ -41,12 +41,15 @@
 
 from typing import Mapping, Union
 
-from struct2tensor import path
-from struct2tensor import prensor
 import tensorflow.compat.v2 as tf
+from tensorflow.python.ops.ragged.row_partition import (
+    RowPartition,  # pylint: disable=g-direct-tensorflow-import
+)
+from tensorflow.python.ops.structured import (
+    structured_tensor,  # pylint: disable=g-direct-tensorflow-import
+)
 
-from tensorflow.python.ops.ragged.row_partition import RowPartition  # pylint: disable=g-direct-tensorflow-import
-from tensorflow.python.ops.structured import structured_tensor  # pylint: disable=g-direct-tensorflow-import
+from struct2tensor import path, prensor
 
 
 def structured_tensor_to_prensor(
@@ -130,6 +133,7 @@ def _expand_dims(st, axis):
 
   Returns:
     a tensor with one more dimension (see tf.expand_dims).
+
   Raises:
     ValueError: if the axis is not valid.
 
@@ -249,7 +253,6 @@ def _partition_if_not_vector(values: tf.Tensor, dtype: tf.dtypes.DType):
   Raises:
     ValueError: if the shape cannot be statically determined or is a scalar.
""" - values_shape = values.shape assert values_shape is not None values_rank = values_shape.rank @@ -273,12 +276,14 @@ def _fully_partitioned_ragged_tensor(rt: Union[tf.RaggedTensor, tf.Tensor], A fully partitioned ragged tensor is: 1. A ragged tensor. 2. The final values are a vector. + Args: rt: input to coerce from RaggedTensor or Tensor. Must be at least 2D. dtype: requested dtype for partitions: tf.int64 or tf.int32. Returns: A ragged tensor where the flat values are a 1D tensor. + Raises: ValueError: if the tensor is 0D or 1D. """ diff --git a/struct2tensor/structured_tensor_to_prensor_test.py b/struct2tensor/structured_tensor_to_prensor_test.py index b4d4b86..f84206c 100644 --- a/struct2tensor/structured_tensor_to_prensor_test.py +++ b/struct2tensor/structured_tensor_to_prensor_test.py @@ -13,15 +13,17 @@ # limitations under the License. """Tests for StructuredTensor.""" -from absl.testing import parameterized - import numpy as np -from struct2tensor import path -from struct2tensor import structured_tensor_to_prensor import tensorflow.compat.v2 as tf - -from tensorflow.python.ops.ragged.row_partition import RowPartition # pylint: disable=g-direct-tensorflow-import -from tensorflow.python.ops.structured import structured_tensor # pylint: disable=g-direct-tensorflow-import +from absl.testing import parameterized +from tensorflow.python.ops.ragged.row_partition import ( + RowPartition, # pylint: disable=g-direct-tensorflow-import +) +from tensorflow.python.ops.structured import ( + structured_tensor, # pylint: disable=g-direct-tensorflow-import +) + +from struct2tensor import path, structured_tensor_to_prensor def _make_structured_tensor(shape, fields): diff --git a/struct2tensor/test/expression_test_util.py b/struct2tensor/test/expression_test_util.py index de09f9c..d8a0edc 100644 --- a/struct2tensor/test/expression_test_util.py +++ b/struct2tensor/test/expression_test_util.py @@ -18,15 +18,11 @@ from typing import FrozenSet, Mapping, Optional, Sequence -from struct2tensor import calculate -from struct2tensor import calculate_options -from struct2tensor import expression -from struct2tensor import path -from struct2tensor import prensor import tensorflow as tf - from tensorflow_metadata.proto.v0 import schema_pb2 +from struct2tensor import calculate, calculate_options, expression, path, prensor + def calculate_value_slowly( expr: expression.Expression, @@ -116,8 +112,7 @@ def __init__(self, def calculate_output(self): """The output returned by this expression.""" if self._calculate_output is None: - raise ValueError("Did not specify calculate_output for {}".format( - self._name)) + raise ValueError(f"Did not specify calculate_output for {self._name}") return self._calculate_output def get_source_expressions(self) -> Sequence[expression.Expression]: @@ -130,7 +125,7 @@ def calculate( options: calculate_options.Options, side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: if len(source_tensors) != len(self._expected_source_tensors): - raise ValueError("Unexpected number of inputs for {}.".format(self._name)) + raise ValueError(f"Unexpected number of inputs for {self._name}.") for i in range(len(source_tensors)): if self._expected_source_tensors[i] is not source_tensors[i]: raise ValueError("Error calculating " + self._name) diff --git a/struct2tensor/test/prensor_test_util.py b/struct2tensor/test/prensor_test_util.py index 90178b6..0ee7724 100644 --- a/struct2tensor/test/prensor_test_util.py +++ b/struct2tensor/test/prensor_test_util.py @@ -19,13 +19,12 @@ from typing 
import Any, Sequence -from struct2tensor import path -from struct2tensor import prensor import tensorflow as tf - from google.protobuf import text_format from tensorflow_metadata.proto.v0 import schema_pb2 +from struct2tensor import path, prensor + def create_root_node(size: int) -> prensor.RootNodeTensor: return prensor.RootNodeTensor(tf.constant(size, dtype=tf.int64)) diff --git a/struct2tensor/tools/build_docs.py b/struct2tensor/tools/build_docs.py index b4f5dc6..3881fa3 100644 --- a/struct2tensor/tools/build_docs.py +++ b/struct2tensor/tools/build_docs.py @@ -22,14 +22,12 @@ import shutil import tempfile -from absl import app -from absl import flags +import yaml +from absl import app, flags +from tensorflow_docs.api_generator import generate_lib, public_api + import struct2tensor as s2t from struct2tensor import expression_impl -from tensorflow_docs.api_generator import generate_lib -from tensorflow_docs.api_generator import public_api - -import yaml flags.DEFINE_string("output_dir", "/tmp/s2t_api", "Where to output the docs") diff --git a/struct2tensor/tools/docker_build/build_manylinux.sh b/struct2tensor/tools/docker_build/build_manylinux.sh index 6120da3..2a6a37c 100755 --- a/struct2tensor/tools/docker_build/build_manylinux.sh +++ b/struct2tensor/tools/docker_build/build_manylinux.sh @@ -117,4 +117,3 @@ set -x setup_environment && \ bazel_build && \ stamp_wheel - From 64d20310060ac3f7ecb0e16745964305c4bdf0a2 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 14 Aug 2025 19:53:27 -0700 Subject: [PATCH 02/17] Add `noqa` for Ruff rule E721 --- struct2tensor/expression_impl/proto.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/struct2tensor/expression_impl/proto.py b/struct2tensor/expression_impl/proto.py index 95a03ae..61ce4cd 100644 --- a/struct2tensor/expression_impl/proto.py +++ b/struct2tensor/expression_impl/proto.py @@ -415,7 +415,7 @@ def calculate_from_parsed_field( def calculation_equal(self, expr: expression.Expression) -> bool: # Ensure that we're dealing with the _ProtoChildExpression and not any # of its subclasses. - if type(expr) != _ProtoChildExpression: # pylint: disable=unidiomatic-typecheck + if type(expr) != _ProtoChildExpression: # noqa: E721 return False expr = cast(_ProtoChildExpression, expr) # Keep pytype happy. 
return (self._desc == expr._desc and # pylint: disable=protected-access @@ -439,8 +439,7 @@ def __init__(self, parent: "_ParentProtoExpression", desc: descriptor.Descriptor, is_repeated: bool, name_as_field: StrStep, transform_fn: TransformFn, backing_str_tensor: Optional[tf.Tensor]): - super(_TransformProtoChildExpression, - self).__init__(parent, desc, is_repeated, name_as_field, + super().__init__(parent, desc, is_repeated, name_as_field, backing_str_tensor) self._transform_fn = transform_fn From e72a36a0fffb6555d8252e06dfcf4425c934df0b Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 14 Aug 2025 19:54:14 -0700 Subject: [PATCH 03/17] Apply fixes for Ruff rule UP008 --- struct2tensor/expression_impl/filter_expression.py | 2 +- struct2tensor/expression_impl/index.py | 3 +-- struct2tensor/expression_impl/map_prensor_to_prensor.py | 3 +-- struct2tensor/expression_impl/parquet.py | 6 ++---- struct2tensor/expression_impl/parquet_test.py | 2 +- struct2tensor/expression_impl/raw_parquet_dataset_test.py | 2 +- 6 files changed, 7 insertions(+), 11 deletions(-) diff --git a/struct2tensor/expression_impl/filter_expression.py b/struct2tensor/expression_impl/filter_expression.py index dca3fe0..4b7d026 100644 --- a/struct2tensor/expression_impl/filter_expression.py +++ b/struct2tensor/expression_impl/filter_expression.py @@ -213,7 +213,7 @@ class _FilterChildByParentIndicesToKeepExpression(expression.Expression): def __init__(self, origin: expression.Expression, parent: expression.Expression): - super(_FilterChildByParentIndicesToKeepExpression, self).__init__( + super().__init__( origin.is_repeated, origin.type, validate_step_format=origin.validate_step_format, diff --git a/struct2tensor/expression_impl/index.py b/struct2tensor/expression_impl/index.py index 8ce3329..5154c21 100644 --- a/struct2tensor/expression_impl/index.py +++ b/struct2tensor/expression_impl/index.py @@ -237,8 +237,7 @@ def __init__(self, positional_index: expression.Expression, get_positional_index). size_inp: the size of the field (from size.size). 
""" - super(_PositionalIndexFromEndExpression, - self).__init__(positional_index.is_repeated, tf.int64) + super().__init__(positional_index.is_repeated, tf.int64) self._positional_index = positional_index self._size = size_inp diff --git a/struct2tensor/expression_impl/map_prensor_to_prensor.py b/struct2tensor/expression_impl/map_prensor_to_prensor.py index 0a66635..d95ecc5 100644 --- a/struct2tensor/expression_impl/map_prensor_to_prensor.py +++ b/struct2tensor/expression_impl/map_prensor_to_prensor.py @@ -315,8 +315,7 @@ class _PrensorAsLeafNodeTensor(prensor.LeafNodeTensor): def __init__(self, prensor_tree: prensor.Prensor, leaf: prensor.LeafNodeTensor): """Call _tree_as_node instead.""" - super(_PrensorAsLeafNodeTensor, - self).__init__(leaf.parent_index, leaf.values, leaf.is_repeated) + super().__init__(leaf.parent_index, leaf.values, leaf.is_repeated) self._prensor = prensor_tree @property diff --git a/struct2tensor/expression_impl/parquet.py b/struct2tensor/expression_impl/parquet.py index d85f8ee..144c16a 100644 --- a/struct2tensor/expression_impl/parquet.py +++ b/struct2tensor/expression_impl/parquet.py @@ -257,8 +257,7 @@ def __init__(self, filenames: List[str], value_paths: List[str], self._create_parent_index_paths_and_index_from_type_spec( self.element_structure, 0, 0) - super(ParquetDataset, - self).__init__(filenames, self._value_paths, self._value_dtypes, + super().__init__(filenames, self._value_paths, self._value_dtypes, self._parent_index_paths, self._path_index, batch_size) def _get_column_dtypes( @@ -506,8 +505,7 @@ def __init__(self, exprs: List[expression.Expression], parquet_paths = [".".join(p.field_list) for p in paths] - super(_ParquetDatasetWithExpression, - self).__init__(filenames, parquet_paths, batch_size) + super().__init__(filenames, parquet_paths, batch_size) def _calculate_prensor(self, pren) -> List[prensor.Prensor]: """Function for applying expression queries to a prensor. 
diff --git a/struct2tensor/expression_impl/parquet_test.py b/struct2tensor/expression_impl/parquet_test.py index efb0a2a..4a44c9a 100644 --- a/struct2tensor/expression_impl/parquet_test.py +++ b/struct2tensor/expression_impl/parquet_test.py @@ -30,7 +30,7 @@ class ParquetDatasetTestBase(tf.test.TestCase): def setUp(self): - super(ParquetDatasetTestBase, self).setUp() + super().setUp() self._test_filenames = [ "struct2tensor/testdata/parquet_testdata/dremel_example.parquet" ] diff --git a/struct2tensor/expression_impl/raw_parquet_dataset_test.py b/struct2tensor/expression_impl/raw_parquet_dataset_test.py index e1c826a..740c84f 100644 --- a/struct2tensor/expression_impl/raw_parquet_dataset_test.py +++ b/struct2tensor/expression_impl/raw_parquet_dataset_test.py @@ -30,7 +30,7 @@ class ParquetDatasetTestBase(tf.test.TestCase): """ def setUp(self): - super(ParquetDatasetTestBase, self).setUp() + super().setUp() self._test_filenames = [ "struct2tensor/testdata/parquet_testdata/dremel_example.parquet" ] From 27a6b79b200b88f56e5fbc255c29f572bc9cb28c Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 14 Aug 2025 19:58:11 -0700 Subject: [PATCH 04/17] Make compliant with Ruff rule UP032 --- struct2tensor/path.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/struct2tensor/path.py b/struct2tensor/path.py index 698e975..bf44456 100644 --- a/struct2tensor/path.py +++ b/struct2tensor/path.py @@ -238,7 +238,7 @@ def as_proto(self): elif isinstance(x, AnonymousId): raise ValueError("Cannot serialize a path with anonymous fields") else: - raise ValueError("Unexpected path element type: %s" % type(x)) + raise ValueError(f"Unexpected path element type: {type(x)}") return result def __repr__(self) -> str: From cb66c6c7ab69efb9dddb2017188e091eab1bd1fb Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 14 Aug 2025 19:59:31 -0700 Subject: [PATCH 05/17] Make compliant with Ruff rule B007 --- struct2tensor/expression_impl/parquet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/struct2tensor/expression_impl/parquet.py b/struct2tensor/expression_impl/parquet.py index 144c16a..a9775df 100644 --- a/struct2tensor/expression_impl/parquet.py +++ b/struct2tensor/expression_impl/parquet.py @@ -302,7 +302,7 @@ def _validate_file(self, filename: str, value_paths: List[str]): p = (col.path) paths[p] = col.physical_type - for i, p in enumerate(value_paths): + for p in value_paths: if p not in paths: raise ValueError("path " + p + " does not exist in the file.") From 8a236ca3ef6320b8ee01822a3d10ed750d05b1c8 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 14 Aug 2025 20:03:27 -0700 Subject: [PATCH 06/17] Make compliant with Ruff rule SIM117 --- .../benchmarks/struct2tensor_benchmark_util.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/struct2tensor/benchmarks/struct2tensor_benchmark_util.py b/struct2tensor/benchmarks/struct2tensor_benchmark_util.py index 1f57422..466262d 100644 --- a/struct2tensor/benchmarks/struct2tensor_benchmark_util.py +++ b/struct2tensor/benchmarks/struct2tensor_benchmark_util.py @@ -76,17 +76,20 @@ def run_benchmarks(self, fn_name, get_benchmark_fn, fn_args, data_key): iterations = 2 # 2 iterations so we can calculate stdev. sample_size = 1 - with tf.Graph().as_default(): - # Disable control dependency optimization. 
Otherwise we cannot use the - # trick to force the parsing to happen without using its output. e.g.: - # with tf.control_dependencies([parsed_tensors]): - # return tf.constant(1) - with tf.compat.v1.Session( + with ( + tf.Graph().as_default(), + # Disable control dependency optimization. Otherwise we cannot use the + # trick to force the parsing to happen without using its output. e.g.: + # with tf.control_dependencies([parsed_tensors]): + # return tf.constant(1) + tf.compat.v1.Session( config=tf.compat.v1.ConfigProto( graph_options=tf.compat.v1.GraphOptions( rewrite_options=rewriter_config_pb2.RewriterConfig( dependency_optimization=rewriter_config_pb2.RewriterConfig - .OFF)))) as sess: + .OFF)))) + as sess + ): benchmark_fn = get_benchmark_fn(sess, *fn_args) for input_data in self._data[data_key]: From de27b08f26ac649dc9d26e2897565255a13324e1 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 14 Aug 2025 20:05:03 -0700 Subject: [PATCH 07/17] Make compliant with Ruff rule SIM108 --- struct2tensor/ops/struct2tensor_ops.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/struct2tensor/ops/struct2tensor_ops.py b/struct2tensor/ops/struct2tensor_ops.py index cff8385..ff5414a 100644 --- a/struct2tensor/ops/struct2tensor_ops.py +++ b/struct2tensor/ops/struct2tensor_ops.py @@ -306,10 +306,7 @@ def parse_proto_map( keys_needed_as_list = list(keys_needed) value_fd = map_entry_descriptor.fields_by_name["value"] - if tf.is_tensor(backing_str_tensor): - backing_str_tensor = [backing_str_tensor] - else: - backing_str_tensor = [] + backing_str_tensor = [backing_str_tensor] if tf.is_tensor(backing_str_tensor) else [] values, parent_indices = gen_decode_proto_map_op.decode_proto_map_v2( map_entries, map_entry_parent_indices, backing_str_tensor, From 14885169e86a058dba504a52aaa2e1d302f020ed Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 14 Aug 2025 20:06:20 -0700 Subject: [PATCH 08/17] Make compliant with Ruff rule SIM105 --- struct2tensor/tools/build_docs.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/struct2tensor/tools/build_docs.py b/struct2tensor/tools/build_docs.py index 3881fa3..8ad3f36 100644 --- a/struct2tensor/tools/build_docs.py +++ b/struct2tensor/tools/build_docs.py @@ -17,6 +17,7 @@ --output_dir=$(pwd)/struct2tensor/opensource_only/g3doc/api_docs/python """ +import contextlib import inspect import pathlib import shutil @@ -89,11 +90,9 @@ def splice(name, tmp_dir): shutil.rmtree(output_dir / name, ignore_errors=True) shutil.copytree(tmp_dir / name, output_dir / name) shutil.copy(tmp_dir / f"{name}.md", output_dir / f"{name}.md") - try: + with contextlib.suppress(FileNotFoundError): shutil.copy(tmp_dir / name / "_redirects.yaml", output_dir / name / "_redirects.yaml") - except FileNotFoundError: - pass shutil.copy(tmp_dir / name / "_toc.yaml", output_dir / name / "_toc.yaml") splice("s2t", s2t_out) From 76028b0f1ae1d4c91a45f45a3ffdf0de65fc09d4 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 14 Aug 2025 20:07:23 -0700 Subject: [PATCH 09/17] Make compliant with Ruff rule SIM110 --- struct2tensor/ops/file_descriptor_set.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/struct2tensor/ops/file_descriptor_set.py b/struct2tensor/ops/file_descriptor_set.py index 73dbbaf..db3f4b4 100644 --- a/struct2tensor/ops/file_descriptor_set.py 
+++ b/struct2tensor/ops/file_descriptor_set.py @@ -30,10 +30,7 @@ def _are_dependencies_handled(file_descriptor: descriptor.FileDescriptor, dependencies: Set[descriptor.Descriptor]) -> bool: """Returns True iff dependencies of descriptor are in dependencies.""" - for dependency in file_descriptor.dependencies: - if dependency not in dependencies: - return False - return True + return all(dependency in dependencies for dependency in file_descriptor.dependencies) def _order_dependencies(file_descriptors: Set[descriptor.FileDescriptor] From 390efe0e7f4befdb2d5ef80050468488f36160ac Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 14 Aug 2025 20:09:44 -0700 Subject: [PATCH 10/17] Make compliant with Ruff rule SIM101 --- struct2tensor/calculate.py | 3 +-- struct2tensor/expression_impl/proto.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/struct2tensor/calculate.py b/struct2tensor/calculate.py index b5001eb..1245757 100644 --- a/struct2tensor/calculate.py +++ b/struct2tensor/calculate.py @@ -274,8 +274,7 @@ def calculate(self, expected_type = self.expression.type actual_value = self.value if expected_type is None: - if not (isinstance(actual_value, prensor.RootNodeTensor) or - isinstance(actual_value, prensor.ChildNodeTensor)): + if not isinstance(actual_value, (prensor.RootNodeTensor, prensor.ChildNodeTensor)): raise self._create_value_error() elif isinstance(actual_value, prensor.LeafNodeTensor): if expected_type != actual_value.values.dtype: diff --git a/struct2tensor/expression_impl/proto.py b/struct2tensor/expression_impl/proto.py index 61ce4cd..2fb1cdb 100644 --- a/struct2tensor/expression_impl/proto.py +++ b/struct2tensor/expression_impl/proto.py @@ -289,8 +289,7 @@ def calculate( options: calculate_options.Options, side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: [parent_value] = sources - if isinstance(parent_value, _ProtoRootNodeTensor) or isinstance( - parent_value, _ProtoChildNodeTensor): + if isinstance(parent_value, (_ProtoRootNodeTensor, _ProtoChildNodeTensor)): parsed_field = parent_value.fields.get(self.name_as_field) if parsed_field is None: raise ValueError(f"Cannot find {str(self)} in {str(parent_value)}") From 57a5b2b60b4e94830fb6d353cedfb84df3947885 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 14 Aug 2025 20:10:58 -0700 Subject: [PATCH 11/17] Make compliant with Ruff rule SIM118 --- struct2tensor/expression_add.py | 4 ++-- struct2tensor/expression_impl/map_prensor.py | 2 +- struct2tensor/expression_impl/map_prensor_to_prensor.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/struct2tensor/expression_add.py b/struct2tensor/expression_add.py index fd09fbe..547d64b 100644 --- a/struct2tensor/expression_add.py +++ b/struct2tensor/expression_add.py @@ -103,7 +103,7 @@ def known_field_names(self) -> FrozenSet[path.Step]: return self._origin.known_field_names().union(self._path_map.keys()) def __str__(self) -> str: - keys_to_add = ",".join([str(k) for k in self._path_map.keys()]) + keys_to_add = ",".join([str(k) for k in self._path_map]) return f"_AddPathsExpression({str(self._origin)}, [{keys_to_add}])" @@ -125,7 +125,7 @@ def add_paths(root: expression.Expression, Returns: a new tree with the nodes from the root and the new subtrees. 
""" - for p in path_map.keys(): + for p in path_map: if root.get_descendant(p.get_parent()) is None: raise ValueError(f"No parent of {p}") if root.get_descendant(p) is not None: diff --git a/struct2tensor/expression_impl/map_prensor.py b/struct2tensor/expression_impl/map_prensor.py index 132fdd7..d1ef1f0 100644 --- a/struct2tensor/expression_impl/map_prensor.py +++ b/struct2tensor/expression_impl/map_prensor.py @@ -184,7 +184,7 @@ def __init__(self, origin: expression.Expression, def _get_source_paths(self) -> Sequence[path.Path]: """Returns the source paths in a deterministic order.""" - result = [k for k in self._origin.get_known_descendants().keys()] + result = [k for k in self._origin.get_known_descendants()] result.sort() return result diff --git a/struct2tensor/expression_impl/map_prensor_to_prensor.py b/struct2tensor/expression_impl/map_prensor_to_prensor.py index d95ecc5..806be00 100644 --- a/struct2tensor/expression_impl/map_prensor_to_prensor.py +++ b/struct2tensor/expression_impl/map_prensor_to_prensor.py @@ -437,7 +437,7 @@ def schema(self): def _get_source_paths(self) -> Sequence[path.Path]: """Returns the source paths in a deterministic order.""" - result = [k for k in self._origin.get_known_descendants().keys()] + result = [k for k in self._origin.get_known_descendants()] # In order to make certain that the source_paths are in a deterministic # order, we sort them here. result.sort() From d58da3014aeedb6d3de9e645bce2fae63c089b46 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 14 Aug 2025 20:15:14 -0700 Subject: [PATCH 12/17] Make compliant with Ruff rule D200 --- struct2tensor/benchmarks/struct2tensor_benchmark.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/struct2tensor/benchmarks/struct2tensor_benchmark.py b/struct2tensor/benchmarks/struct2tensor_benchmark.py index 639434a..bf0c268 100644 --- a/struct2tensor/benchmarks/struct2tensor_benchmark.py +++ b/struct2tensor/benchmarks/struct2tensor_benchmark.py @@ -11,10 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -r"""Benchmarks for struct2tensor. 
- - -""" +r"""Benchmarks for struct2tensor.""" import tensorflow as tf from absl.testing import parameterized From e5d63590fc67d185ffef86b548c7190c8de97258 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 14 Aug 2025 20:18:44 -0700 Subject: [PATCH 13/17] Make compliant with Ruff rule RET505 --- setup.py | 7 ++- struct2tensor/calculate.py | 13 ++--- struct2tensor/create_expression.py | 19 ++++---- struct2tensor/expression_impl/broadcast.py | 5 +- struct2tensor/expression_impl/map_prensor.py | 3 +- struct2tensor/expression_impl/placeholder.py | 3 +- struct2tensor/expression_impl/proto.py | 5 +- .../raw_parquet_dataset_test.py | 16 +++---- .../expression_impl/slice_expression.py | 30 ++++++------ struct2tensor/ops/struct2tensor_ops.py | 3 +- struct2tensor/path.py | 4 +- struct2tensor/prensor.py | 29 +++++------- struct2tensor/structured_tensor_to_prensor.py | 47 +++++++++---------- 13 files changed, 81 insertions(+), 103 deletions(-) diff --git a/setup.py b/setup.py index 1a3ec43..6212f14 100644 --- a/setup.py +++ b/setup.py @@ -46,12 +46,11 @@ def select_constraint(default, nightly=None, git_master=None): selector = os.environ.get('TFX_DEPENDENCY_SELECTOR') if selector == 'UNCONSTRAINED': return '' - elif selector == 'NIGHTLY' and nightly is not None: + if selector == 'NIGHTLY' and nightly is not None: return nightly - elif selector == 'GIT_MASTER' and git_master is not None: + if selector == 'GIT_MASTER' and git_master is not None: return git_master - else: - return default + return default # Get the long description from the README file. diff --git a/struct2tensor/calculate.py b/struct2tensor/calculate.py index 1245757..0b187ea 100644 --- a/struct2tensor/calculate.py +++ b/struct2tensor/calculate.py @@ -208,8 +208,7 @@ def _fancy_type_str(is_repeated: bool, dtype: Optional[tf.DType]): def _node_type_str(node_tensor: prensor.NodeTensor): if isinstance(node_tensor, prensor.LeafNodeTensor): return _fancy_type_str(node_tensor.is_repeated, node_tensor.values.dtype) - else: - return _fancy_type_str(node_tensor.is_repeated, None) + return _fancy_type_str(node_tensor.is_repeated, None) class _ExpressionNode: @@ -465,10 +464,9 @@ def _create_node_if_not_exists(self, expr: expression.Expression if maybe_canonical in self._node_map: self._node[id(expr)] = self._node_map[maybe_canonical] return None - else: - self._node_map[maybe_canonical] = maybe_canonical - self._node[id(expr)] = maybe_canonical - return maybe_canonical + self._node_map[maybe_canonical] = maybe_canonical + self._node[id(expr)] = maybe_canonical + return maybe_canonical def _get_canonical_or_error(self, expr: expression.Expression ) -> expression.Expression: @@ -476,5 +474,4 @@ def _get_canonical_or_error(self, expr: expression.Expression node = self._get_node(expr) if node is not None: return node.expression - else: - raise ValueError("Expression not found: " + str(expr)) + raise ValueError("Expression not found: " + str(expr)) diff --git a/struct2tensor/create_expression.py b/struct2tensor/create_expression.py index 1c5436b..18bc9d9 100644 --- a/struct2tensor/create_expression.py +++ b/struct2tensor/create_expression.py @@ -112,7 +112,7 @@ def create_expression_from_prensor( return _DirectExpression( True, None, node_tensor, children, validate_step_format ) - elif isinstance(node_tensor, prensor.ChildNodeTensor): + if isinstance(node_tensor, prensor.ChildNodeTensor): return _DirectExpression( node_tensor.is_repeated, None, @@ -120,12 +120,11 @@ def 
create_expression_from_prensor( children, validate_step_format, ) - else: - # isinstance(node_tensor, LeafNodeTensor) - return _DirectExpression( - node_tensor.is_repeated, - node_tensor.values.dtype, - node_tensor, - children, - validate_step_format, - ) + # isinstance(node_tensor, LeafNodeTensor) + return _DirectExpression( + node_tensor.is_repeated, + node_tensor.values.dtype, + node_tensor, + children, + validate_step_format, + ) diff --git a/struct2tensor/expression_impl/broadcast.py b/struct2tensor/expression_impl/broadcast.py index ac63eca..1164212 100644 --- a/struct2tensor/expression_impl/broadcast.py +++ b/struct2tensor/expression_impl/broadcast.py @@ -209,9 +209,8 @@ def calculate( new_values = tf.gather(origin_value.values, index_to_values) return prensor.LeafNodeTensor(broadcasted_to_sibling_index, new_values, self.is_repeated) - else: - return prensor.ChildNodeTensor(broadcasted_to_sibling_index, - self.is_repeated, index_to_values) + return prensor.ChildNodeTensor(broadcasted_to_sibling_index, + self.is_repeated, index_to_values) def calculation_is_identity(self) -> bool: return False diff --git a/struct2tensor/expression_impl/map_prensor.py b/struct2tensor/expression_impl/map_prensor.py index d1ef1f0..f226b7b 100644 --- a/struct2tensor/expression_impl/map_prensor.py +++ b/struct2tensor/expression_impl/map_prensor.py @@ -253,8 +253,7 @@ def _as_leaf_node(sparse_tensor: tf.SparseTensor, is_repeated: bool, if options.sparse_checks: return _as_leaf_node_with_checks(sparse_tensor, is_repeated, required_batch_size) - else: - return _as_leaf_node_no_checks(sparse_tensor, is_repeated) + return _as_leaf_node_no_checks(sparse_tensor, is_repeated) def _map_prensor_impl( diff --git a/struct2tensor/expression_impl/placeholder.py b/struct2tensor/expression_impl/placeholder.py index f7a79c8..c6e3444 100644 --- a/struct2tensor/expression_impl/placeholder.py +++ b/struct2tensor/expression_impl/placeholder.py @@ -179,8 +179,7 @@ def calculate( # pytype: disable=signature-mismatch # overriding-parameter-typ raise ValueError("_PlaceholderRootExpression has no sources") if side_info: return mpp._tree_as_node(side_info) # pylint: disable=protected-access - else: - raise ValueError("_PlaceholderRootExpression requires side_info") + raise ValueError("_PlaceholderRootExpression requires side_info") def calculation_is_identity(self) -> bool: return False diff --git a/struct2tensor/expression_impl/proto.py b/struct2tensor/expression_impl/proto.py index 2fb1cdb..28faff3 100644 --- a/struct2tensor/expression_impl/proto.py +++ b/struct2tensor/expression_impl/proto.py @@ -597,9 +597,8 @@ def _get_any_child( field_message = desc.file.pool.FindMessageTypeByName(full_name_child) return _ProtoChildExpression(parent, field_message, False, field_name, backing_str_tensor) - else: - return _get_child_helper(parent, desc.fields_by_name.get(field_name), - field_name, backing_str_tensor) + return _get_child_helper(parent, desc.fields_by_name.get(field_name), + field_name, backing_str_tensor) def _is_map_field_desc(field_desc: descriptor.FieldDescriptor) -> bool: diff --git a/struct2tensor/expression_impl/raw_parquet_dataset_test.py b/struct2tensor/expression_impl/raw_parquet_dataset_test.py index 740c84f..cecfd9c 100644 --- a/struct2tensor/expression_impl/raw_parquet_dataset_test.py +++ b/struct2tensor/expression_impl/raw_parquet_dataset_test.py @@ -88,8 +88,7 @@ def _wrapper(): r = gn() if isinstance(r, tf.TensorArray): return r.stack() - else: - return r + return r return _wrapper @@ -97,14 +96,13 @@ def 
_wrapper(): if tf.executing_eagerly() or building_function: iterator = iter(dataset) return ta_wrapper(iterator._next_internal) + if requires_initialization: + iterator = tf.compat.v1.data.make_initializable_iterator(dataset) + self.evaluate(iterator.initializer) else: - if requires_initialization: - iterator = tf.compat.v1.data.make_initializable_iterator(dataset) - self.evaluate(iterator.initializer) - else: - iterator = tf.compat.v1.data.make_one_shot_iterator(dataset) - get_next = iterator.get_next() - return ta_wrapper(lambda: get_next) + iterator = tf.compat.v1.data.make_one_shot_iterator(dataset) + get_next = iterator.get_next() + return ta_wrapper(lambda: get_next) def assertDatasetProduces(self, dataset, diff --git a/struct2tensor/expression_impl/slice_expression.py b/struct2tensor/expression_impl/slice_expression.py index 8e42701..81507dc 100644 --- a/struct2tensor/expression_impl/slice_expression.py +++ b/struct2tensor/expression_impl/slice_expression.py @@ -192,15 +192,14 @@ def _get_mask(t: expression.Expression, p: path.Path, threshold: IndexValue, if threshold >= 0: return work_expr, mask_for_non_negative_threshold return work_expr, mask_for_negative_threshold - else: - def tf_cond_on_threshold(a, b): - return tf.cond(tf.greater_equal(threshold, 0), a, b) + def tf_cond_on_threshold(a, b): + return tf.cond(tf.greater_equal(threshold, 0), a, b) - return map_values.map_many_values(work_expr, p.get_parent(), [ - x.field_list[-1] - for x in [mask_for_non_negative_threshold, mask_for_negative_threshold] - ], tf_cond_on_threshold, tf.bool, path.get_anonymous_field()) + return map_values.map_many_values(work_expr, p.get_parent(), [ + x.field_list[-1] + for x in [mask_for_non_negative_threshold, mask_for_negative_threshold] + ], tf_cond_on_threshold, tf.bool, path.get_anonymous_field()) def _get_begin_mask(expr: expression.Expression, p: path.Path, begin: IndexValue @@ -257,12 +256,11 @@ def _get_slice_mask( if end is None: raise ValueError("Must specify begin or end.") return _get_end_mask(expr, p, end) - else: - if end is None: - return _get_begin_mask(expr, p, begin) - work_expr, begin_mask = _get_begin_mask(expr, p, begin) - work_expr, end_mask = _get_end_mask(work_expr, p, end) - return map_values.map_many_values( - work_expr, p.get_parent(), - [x.field_list[-1] for x in [begin_mask, end_mask]], tf.logical_and, - tf.bool, path.get_anonymous_field()) + if end is None: + return _get_begin_mask(expr, p, begin) + work_expr, begin_mask = _get_begin_mask(expr, p, begin) + work_expr, end_mask = _get_end_mask(work_expr, p, end) + return map_values.map_many_values( + work_expr, p.get_parent(), + [x.field_list[-1] for x in [begin_mask, end_mask]], tf.logical_and, + tf.bool, path.get_anonymous_field()) diff --git a/struct2tensor/ops/struct2tensor_ops.py b/struct2tensor/ops/struct2tensor_ops.py index ff5414a..af6ac78 100644 --- a/struct2tensor/ops/struct2tensor_ops.py +++ b/struct2tensor/ops/struct2tensor_ops.py @@ -112,8 +112,7 @@ def _get_field_descriptor(descriptor_type: descriptor.Descriptor, if path.is_extension(field_name): return descriptor_type.file.pool.FindExtensionByName( path.get_raw_extension_name(field_name)) - else: - return descriptor_type.fields_by_name[field_name] + return descriptor_type.fields_by_name[field_name] def parse_message_level( diff --git a/struct2tensor/path.py b/struct2tensor/path.py index bf44456..ff8fd25 100644 --- a/struct2tensor/path.py +++ b/struct2tensor/path.py @@ -67,7 +67,7 @@ def _compare_step(a: Step, b: Step) -> int: if aint == bint: if a > b: 
return 1 - elif a < b: + if a < b: return -1 return 0 @@ -394,7 +394,7 @@ def expand_wildcard_proto_paths( if path == ["*"]: return _get_all_subfields(proto_descriptor) # path prefixes ending with "*" will be expanded - elif path[-1] == "*" and len(path) > 1: + if path[-1] == "*" and len(path) > 1: paths_with_wildcards.append(path[:-1]) else: # partial matching paths (e.g., ["feature*"]) will raise an error diff --git a/struct2tensor/prensor.py b/struct2tensor/prensor.py index e7c4786..caa166e 100644 --- a/struct2tensor/prensor.py +++ b/struct2tensor/prensor.py @@ -720,22 +720,19 @@ def _get_dewey_encoding( if p.tail.is_repeated: return (tf.stack([p.tail.parent_index, positional_index], axis=1), tf.concat([parent_size, current_size], 0)) - else: - return tf.expand_dims(p.tail.parent_index, -1), parent_size - else: - parent_dewey_encoding, parent_size = _get_dewey_encoding( - _get_node_path_parent(p)) - if p.tail.is_repeated: - positional_index_as_matrix = tf.expand_dims( - p.tail.get_positional_index(), -1) - indices = tf.concat([ - tf.gather(parent_dewey_encoding, p.tail.parent_index), - positional_index_as_matrix - ], 1) - size = tf.concat([parent_size, current_size], 0) - return (indices, size) - else: - return tf.gather(parent_dewey_encoding, p.tail.parent_index), parent_size + return tf.expand_dims(p.tail.parent_index, -1), parent_size + parent_dewey_encoding, parent_size = _get_dewey_encoding( + _get_node_path_parent(p)) + if p.tail.is_repeated: + positional_index_as_matrix = tf.expand_dims( + p.tail.get_positional_index(), -1) + indices = tf.concat([ + tf.gather(parent_dewey_encoding, p.tail.parent_index), + positional_index_as_matrix + ], 1) + size = tf.concat([parent_size, current_size], 0) + return (indices, size) + return tf.gather(parent_dewey_encoding, p.tail.parent_index), parent_size def _get_sparse_tensor_from_leaf_node_path(p: _LeafNodePath) -> tf.SparseTensor: diff --git a/struct2tensor/structured_tensor_to_prensor.py b/struct2tensor/structured_tensor_to_prensor.py index 2f220d8..d607e7a 100644 --- a/struct2tensor/structured_tensor_to_prensor.py +++ b/struct2tensor/structured_tensor_to_prensor.py @@ -76,13 +76,12 @@ def structured_tensor_to_prensor( return prensor.create_prensor_from_root_and_children( prensor.RootNodeTensor((st).nrows()), {default_field_name: child_prensor}) - elif st.rank == 1: + if st.rank == 1: return prensor.create_prensor_from_root_and_children( prensor.RootNodeTensor((st).nrows()), _structured_tensor_prensor_map(st, default_field_name)) - else: - # st is a scalar StructuredTensor. - return structured_tensor_to_prensor(_expand_dims(st, 0), default_field_name) + # st is a scalar StructuredTensor. + return structured_tensor_to_prensor(_expand_dims(st, 0), default_field_name) def _structured_tensor_prensor_map( @@ -115,7 +114,7 @@ def _expand_dims_nonnegative_axis(axis, rank): # Note: this is unreachable in the current code. raise ValueError("Axis out of range: " + str(axis)) return new_axis - elif axis > rank: + if axis > rank: # Note: this is unreachable in the current code. raise ValueError("Axis larger than rank: " + str(axis) + " > " + str(rank)) return axis @@ -148,16 +147,15 @@ def _expand_dims(st, axis): nrows = st.nrows() return st.partition_outer_dimension( RowPartition.from_uniform_row_length(nrows, nrows)) - elif nn_axis == 1: + if nn_axis == 1: # Again, by partitioning the first dimension into vectors of length 1, # we can solve this problem. 
nrows = st.nrows() return st.partition_outer_dimension( RowPartition.from_uniform_row_length( tf.constant(1, dtype=nrows.dtype), nrows)) - else: - # Note: this is unreachable in the current code. - raise ValueError("Unimplemented: non-negative axis > 1 for _expand_dims") + # Note: this is unreachable in the current code. + raise ValueError("Unimplemented: non-negative axis > 1 for _expand_dims") def _structured_tensor_field_to_prensor( @@ -167,8 +165,7 @@ def _structured_tensor_field_to_prensor( """Creates a ChildNodeTensor from a field in a structured tensor.""" if isinstance(field_value, structured_tensor.StructuredTensor): return _structured_tensor_to_child_prensor(field_value, default_field_name) - else: - return _to_leaf_prensor(field_value, default_field_name) + return _to_leaf_prensor(field_value, default_field_name) def _row_partition_to_child_node_tensor(row_partition: RowPartition): @@ -198,7 +195,7 @@ def _structured_tensor_to_child_prensor( return prensor.create_prensor_from_root_and_children( _row_partition_to_child_node_tensor(row_partition), _structured_tensor_prensor_map(child_st, default_field_name)) - elif len(row_partitions) > 1: + if len(row_partitions) > 1: row_partition = row_partitions[0] child_st = st.merge_dims(0, 1) return _one_child_prensor( @@ -232,10 +229,9 @@ def _to_leaf_prensor_helper(rt: tf.RaggedTensor, values = rt.values leaf = prensor.LeafNodeTensor(row_partition.value_rowids(), values, True) return prensor.create_prensor_from_root_and_children(leaf, {}) - else: - return _one_child_prensor( - row_partition, _to_leaf_prensor_helper(rt.values, default_field_name), - default_field_name) + return _one_child_prensor( + row_partition, _to_leaf_prensor_helper(rt.values, default_field_name), + default_field_name) def _partition_if_not_vector(values: tf.Tensor, dtype: tf.dtypes.DType): @@ -291,16 +287,15 @@ def _fully_partitioned_ragged_tensor(rt: Union[tf.RaggedTensor, tf.Tensor], rt = rt.with_row_splits_dtype(dtype) flattened_values = _partition_if_not_vector(rt.flat_values, dtype=dtype) return rt.with_flat_values(flattened_values) - else: - rt_shape = rt.shape - assert rt_shape is not None - rt_rank = rt_shape.rank - assert rt_rank is not None - if rt_rank < 2: - # Increase the rank if it is a scalar. - return _fully_partitioned_ragged_tensor(tf.expand_dims(rt, -1)) - return tf.RaggedTensor.from_tensor( - rt, ragged_rank=rt_rank - 1, row_splits_dtype=dtype) + rt_shape = rt.shape + assert rt_shape is not None + rt_rank = rt_shape.rank + assert rt_rank is not None + if rt_rank < 2: + # Increase the rank if it is a scalar. 
+ return _fully_partitioned_ragged_tensor(tf.expand_dims(rt, -1)) + return tf.RaggedTensor.from_tensor( + rt, ragged_rank=rt_rank - 1, row_splits_dtype=dtype) def _to_leaf_prensor(rt: Union[tf.RaggedTensor, tf.Tensor], From 399ec56d34ced342c4c100adda49a62383e217a1 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 14 Aug 2025 20:20:09 -0700 Subject: [PATCH 14/17] Make compliant with Ruff rule RET504 --- struct2tensor/expression_impl/map_prensor_to_prensor.py | 3 +-- struct2tensor/expression_impl/parquet.py | 3 +-- struct2tensor/path.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/struct2tensor/expression_impl/map_prensor_to_prensor.py b/struct2tensor/expression_impl/map_prensor_to_prensor.py index 806be00..6f1310e 100644 --- a/struct2tensor/expression_impl/map_prensor_to_prensor.py +++ b/struct2tensor/expression_impl/map_prensor_to_prensor.py @@ -253,8 +253,7 @@ def map_prensor_to_prensor( source.get_child(k): prensor_child.get_child_or_error(k) for k in prensor_child.known_field_names() } - result = expression_add.add_paths(root_expr, paths_map) - return result + return expression_add.add_paths(root_expr, paths_map) ##################### Implementation Follows ################################### diff --git a/struct2tensor/expression_impl/parquet.py b/struct2tensor/expression_impl/parquet.py index a9775df..256387c 100644 --- a/struct2tensor/expression_impl/parquet.py +++ b/struct2tensor/expression_impl/parquet.py @@ -145,12 +145,11 @@ def _get_column_path_to_index_mapping(self, metadata_file) -> Dict[str, int]: """ metadata = pq.ParquetFile(metadata_file).metadata - path_to_column_index = { + return { metadata.schema.column(index).path: index for index in range(metadata.num_columns) } - return path_to_column_index def _parquet_to_tf_type(self, parquet_type: str) -> Union[tf.DType, None]: """Maps tensorflow datatype to a parquet datatype. 
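A note on RET504, the rule this patch applies: it flags a value that is assigned to a local name and then immediately returned, since the intermediate name adds a statement without adding information. A minimal sketch with a hypothetical function (not from this diff):

    def total(xs):
        # RET504 violation:
        #     result = sum(xs)
        #     return result
        # Compliant form: return the expression directly.
        return sum(xs)


    assert total([1, 2, 3]) == 6

Where the intermediate name genuinely documents the value, keeping it with a targeted noqa is the usual alternative; the three cases fixed below gain nothing from the extra name.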
diff --git a/struct2tensor/path.py b/struct2tensor/path.py
index ff8fd25..ec3824e 100644
--- a/struct2tensor/path.py
+++ b/struct2tensor/path.py
@@ -444,5 +444,4 @@ def _proto_glob(
       raise ValueError(f"Field name {field_name} does not exist.")
     proto_descriptor = proto_descriptor.fields_by_name[field_name].message_type
   expanded_leaves = _get_all_subfields(proto_descriptor)
-  expanded_paths = [parent + child for child in expanded_leaves]
-  return expanded_paths
+  return [parent + child for child in expanded_leaves]

From 535fee2e56bdee1b37f6373b4d148c9024b1412e Mon Sep 17 00:00:00 2001
From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com>
Date: Thu, 14 Aug 2025 20:24:00 -0700
Subject: [PATCH 15/17] Make compliant with Ruff rule PD013

---
 struct2tensor/expression_impl/raw_parquet_dataset_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/struct2tensor/expression_impl/raw_parquet_dataset_test.py b/struct2tensor/expression_impl/raw_parquet_dataset_test.py
index cecfd9c..66e3e7b 100644
--- a/struct2tensor/expression_impl/raw_parquet_dataset_test.py
+++ b/struct2tensor/expression_impl/raw_parquet_dataset_test.py
@@ -87,7 +87,7 @@ def ta_wrapper(gn):
   def _wrapper():
     r = gn()
     if isinstance(r, tf.TensorArray):
-      return r.stack()
+      return r.stack()  # noqa: PD013 - tf.TensorArray, not a pandas DataFrame
     return r
   return _wrapper
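A note on PD013, which the hunk above suppresses rather than applies: the rule is a pandas check ("melt is preferred to stack") and matches any .stack( attribute access syntactically, so it fires on tf.TensorArray.stack() as a false positive. tf.TensorArray has no melt() method, and stack() is the supported way to pack the array into a single tensor, hence the noqa instead of a rewrite. A minimal sketch of the distinction (illustrative only, not part of the patch):

    import pandas as pd
    import tensorflow as tf

    # pandas DataFrame: PD013 genuinely applies here.
    df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
    long_form = df.melt()  # preferred over df.stack()

    # tf.TensorArray: stack() is the standard packing API, so the lint
    # warning is silenced rather than "fixed".
    ta = tf.TensorArray(tf.int32, size=2).write(0, 1).write(1, 2)
    packed = ta.stack()  # noqa: PD013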
From ed83546c1a6225e0d7b5a2a7d15b669cbeb1ced5 Mon Sep 17 00:00:00 2001
From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com>
Date: Thu, 14 Aug 2025 22:26:23 -0700
Subject: [PATCH 16/17] Remove bytecode

---
 .../__pycache__/__init__.cpython-311.pyc        | Bin 1948 -> 0 bytes
 .../__pycache__/calculate.cpython-311.pyc       | Bin 27468 -> 0 bytes
 .../calculate_options.cpython-311.pyc           | Bin 2965 -> 0 bytes
 .../calculate_with_source_paths.cpython-311.pyc | Bin 6595 -> 0 bytes
 .../__pycache__/expression.cpython-311.pyc      | Bin 33258 -> 0 bytes
 .../__pycache__/expression_add.cpython-311.pyc  | Bin 10636 -> 0 bytes
 struct2tensor/__pycache__/path.cpython-311.pyc  | Bin 19963 -> 0 bytes
 .../__pycache__/prensor.cpython-311.pyc         | Bin 45082 -> 0 bytes
 .../__pycache__/__init__.cpython-311.pyc        | Bin 1314 -> 0 bytes
 .../__pycache__/apply_schema.cpython-311.pyc    | Bin 9657 -> 0 bytes
 .../__pycache__/broadcast.cpython-311.pyc       | Bin 14908 -> 0 bytes
 .../__pycache__/depth_limit.cpython-311.pyc     | Bin 4244 -> 0 bytes
 .../filter_expression.cpython-311.pyc           | Bin 20909 -> 0 bytes
 .../__pycache__/index.cpython-311.pyc           | Bin 10422 -> 0 bytes
 .../__pycache__/map_prensor.cpython-311.pyc     | Bin 18599 -> 0 bytes
 .../map_prensor_to_prensor.cpython-311.pyc      | Bin 22795 -> 0 bytes
 .../__pycache__/map_values.cpython-311.pyc      | Bin 8911 -> 0 bytes
 .../__pycache__/parquet.cpython-311.pyc         | Bin 28931 -> 0 bytes
 .../__pycache__/placeholder.cpython-311.pyc     | Bin 12019 -> 0 bytes
 .../__pycache__/project.cpython-311.pyc         | Bin 6338 -> 0 bytes
 .../__pycache__/size.cpython-311.pyc            | Bin 8041 -> 0 bytes
 .../ops/__pycache__/__init__.cpython-311.pyc    | Bin 185 -> 0 bytes
 .../file_descriptor_set.cpython-311.pyc         | Bin 6135 -> 0 bytes
 ...riptor_set_test.cpython-311-pytest-8.3.3.pyc | Bin 2030 -> 0 bytes
 .../gen_decode_proto_map_op.cpython-311.pyc     | Bin 695 -> 0 bytes
 .../gen_parquet_dataset.cpython-311.pyc         | Bin 640 -> 0 bytes
 .../struct2tensor_ops.cpython-311.pyc           | Bin 14312 -> 0 bytes
 ...tensor_ops_test.cpython-311-pytest-8.3.3.pyc | Bin 49897 -> 0 bytes
 28 files changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 struct2tensor/__pycache__/__init__.cpython-311.pyc
 delete mode 100644 struct2tensor/__pycache__/calculate.cpython-311.pyc
 delete mode 100644 struct2tensor/__pycache__/calculate_options.cpython-311.pyc
 delete mode 100644 struct2tensor/__pycache__/calculate_with_source_paths.cpython-311.pyc
 delete mode 100644 struct2tensor/__pycache__/expression.cpython-311.pyc
 delete mode 100644 struct2tensor/__pycache__/expression_add.cpython-311.pyc
 delete mode 100644 struct2tensor/__pycache__/path.cpython-311.pyc
 delete mode 100644 struct2tensor/__pycache__/prensor.cpython-311.pyc
 delete mode 100644 struct2tensor/expression_impl/__pycache__/__init__.cpython-311.pyc
 delete mode 100644 struct2tensor/expression_impl/__pycache__/apply_schema.cpython-311.pyc
 delete mode 100644 struct2tensor/expression_impl/__pycache__/broadcast.cpython-311.pyc
 delete mode 100644 struct2tensor/expression_impl/__pycache__/depth_limit.cpython-311.pyc
 delete mode 100644 struct2tensor/expression_impl/__pycache__/filter_expression.cpython-311.pyc
 delete mode 100644 struct2tensor/expression_impl/__pycache__/index.cpython-311.pyc
 delete mode 100644 struct2tensor/expression_impl/__pycache__/map_prensor.cpython-311.pyc
 delete mode 100644 struct2tensor/expression_impl/__pycache__/map_prensor_to_prensor.cpython-311.pyc
 delete mode 100644 struct2tensor/expression_impl/__pycache__/map_values.cpython-311.pyc
 delete mode 100644 struct2tensor/expression_impl/__pycache__/parquet.cpython-311.pyc
 delete mode 100644 struct2tensor/expression_impl/__pycache__/placeholder.cpython-311.pyc
 delete mode 100644 struct2tensor/expression_impl/__pycache__/project.cpython-311.pyc
 delete mode 100644 struct2tensor/expression_impl/__pycache__/size.cpython-311.pyc
 delete mode 100644 struct2tensor/ops/__pycache__/__init__.cpython-311.pyc
 delete mode 100644 struct2tensor/ops/__pycache__/file_descriptor_set.cpython-311.pyc
 delete mode 100644 struct2tensor/ops/__pycache__/file_descriptor_set_test.cpython-311-pytest-8.3.3.pyc
 delete mode 100644 struct2tensor/ops/__pycache__/gen_decode_proto_map_op.cpython-311.pyc
 delete mode 100644 struct2tensor/ops/__pycache__/gen_parquet_dataset.cpython-311.pyc
 delete mode 100644 struct2tensor/ops/__pycache__/struct2tensor_ops.cpython-311.pyc
 delete mode 100644 struct2tensor/ops/__pycache__/struct2tensor_ops_test.cpython-311-pytest-8.3.3.pyc

diff --git a/struct2tensor/__pycache__/__init__.cpython-311.pyc b/struct2tensor/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 86f3ebe5b213591fa8b921da909e38d8d2432518..0000000000000000000000000000000000000000
Binary files a/struct2tensor/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/struct2tensor/__pycache__/calculate_options.cpython-311.pyc b/struct2tensor/__pycache__/calculate_options.cpython-311.pyc
deleted file mode 100644
index 4fa90fec72c0dc2b849f5a22f8698a7eda3e593e..0000000000000000000000000000000000000000
Binary files a/struct2tensor/__pycache__/calculate_options.cpython-311.pyc and /dev/null differ
diff --git a/struct2tensor/__pycache__/expression.cpython-311.pyc b/struct2tensor/__pycache__/expression.cpython-311.pyc
deleted file mode 100644
index 35e11d4dfd49f1de629d75224e0d35ea36a08d3d..0000000000000000000000000000000000000000
Binary files a/struct2tensor/__pycache__/expression.cpython-311.pyc and /dev/null differ
diff --git a/struct2tensor/__pycache__/expression_add.cpython-311.pyc b/struct2tensor/__pycache__/expression_add.cpython-311.pyc
deleted file mode 100644
index e214098ac501fc3d3d957444201f8bf14deb0861..0000000000000000000000000000000000000000
Binary files a/struct2tensor/__pycache__/expression_add.cpython-311.pyc and /dev/null differ

[GIT binary patch blobs for these deletions, the deltas for the remaining .pyc files, and the rest of the series were garbled in extraction and are omitted.]
zUe*II7hbQp0wq^(rDLe%9C|wR@}s%(&=Gy;NLkocxN`q!#no4Gi51V*l4s~?XJ4uR zjdJIkdgq%Z_nQ?@SC!*klTQP~#p6##%7J5g;8@9htkN-1at>5nJ*db~=5*PAM)#j7 zInO-x4V7Mc1!aB5bl#QmBW$$6ZA7DK+ki%U!{KCF$|fl74~NOk)^gBcDIE=m z6*_PlUht?|q)HKJ(B!HN;VMYFA<#LBk!FQZR3TGPyp->qM#$kBJHIRBuatxIRJf5& zCz$hBUZsa*L&_-erAOsNA{+*57{OD!ZCfGr?N%+0?p-Rt?-V?0ak$kqPcSU&UDIR%r=&M(s%_t z!+So*m?`zj(QgLySEuz?-eC7t4i8(zFmJ!W7yD^O^c_DUT^INbq|g;$Z zs&naVLPAu7LdcS-APRywAN=Aqkx)ciFi{hJ#97C)dlK{8C0^s$SKFgj`X4MH41MkEvPOuJB zsvaFOW8{=ypHPagV_rp~wMf!qQd^&oTKtnK|1pbo*YukU$fCk*saOTzGu}LqCc0$} zQ-rup8Ac+C)49NO8qFcP&U$xeG1o)LFGdnD5s4-p{lBvqbelf1lL$1_M=>8*Im1i$ zW_6RVU&aPAfu4@;CVh7(KcTcA!r(Iok3bK`8cxT^lvI^m1q~Cnp}BD7!{=*J#Og zfN|xKwchd7-ti}Vx%YtHd!Xz;sQV9=oCjMfmR;L**Y=VtXddC$dJnDk9{Rhh-*o-@ z?XTZ1_fG1)lV$&L-G98~1kiKt(f{ZkE*>p+kLlfGcSD68g?FEN23M{XkCi>+x@Wv( z8)q14AheK@SoF&fD2!r3NQXI|*#ZNSti*3&_c9}D3;-6Sq_eXyj`|8Q zgS?Wa%le(jy8JUC~0mzahSDlmK-);z?pAJ*ih#pENW{;P{1$6CYzxRM`aKLHg^j;)!Z3Y zwG{i^pe5=gsR)^*12~CHm%%%Xyx3BOixKH zQ_cwp0K*99^e+9ReyQ`zJV@C)psCMW|)OzJZW8Z$>bYd^t^8F>?v6s-eep#vTH@>*H^5Fj0m z)BBX#D=DgXz^uoJ0jo!B$BguWX4OE3dEhz?pjrppYt<wh1uBTwyih zSTbs`ww4n_&*;|+<|;P`8ogwPZ^BP0*fcb30T1fqT#uwvW+xhqjmc8 zqG4Iolict$R*8NA0LNWJVET1Qll$FSoX+)3)p%PpozfzJTyVFLCAdSUc^2F`g#`UB z0>Pba46x|e5j-FZ&(~k=0iVCHEe|#WKD}02h`?-EW-njt$_40B-h(iF=2WsEtxDL1 z*=47F#DwVA3ua7cpytk+;h#p{Z&62`gaG1m_D-QG{>|IBv}5UA@G_TsY1^7-Y}GUN z=q;0pRqd9dX1K(oY=R!;rzDt1iBpP2oSO{xU3zvM0s&~oV2P!%u_*IM89uUtI{PJd zyDs4DGzX(3k@+hs@I3@0vi;&`9iMi55_lNE?-Gc}LflQ^rwZFDo4X&3{c`LVp^rm_ zOK@5hw|~vOYt_B0_-@&~Uw7}XxVzWf`&Qli9$hNCkLvED-#TpmiSIZF%!uPMkC&Gn zy`|n`YrUsedr$w%!GC!Fw}=1vaJl!Q-g~j^zoh#wm7JIUFVwdx)Ti{>;Lgvke0t@R zHy*xGb>QK*v@YL~D2)8x&3U`;A6n@vdxmw-aLG2z=6*W3RrxWd!EoZ#g$-o}h=>WB zK_fCyg`5c9Lc-V^KNevzD;ArI_TzMfFyc2eu^(~r(WdqjN4VR zek$Bh$@-~qJ4)71g*#ZXeyYwces`sF%YEtNcyU1Aak#Ygu-OPO)HiqKAz$2FEyvSD_s9WW%Vq}6x zIB30C8Px~Izom@tnz=c?kAH%XXAYWG7NGi5w4?X18QI6shT^ zoR~{vIUest_G)&NfrIe&5-fL-$!L%a&cMwA*~9K4f9%~Y4tD@)bP$4o0Aug+$NZ<9 zGeD5T;qF(jn`Ad7+jFzIHdWp4b@ltI_p0hu)$=d?elLgX+WP;V{U3)o?!Qx_e7TH= zNB_)o-20rw&2SPg*<$<*kGl}F%?mTaynV(#@0fASJ7=6cy|>3)^X?fpdv?S;^WGUB zo}Dvwl555GcA%&Zj$Qo=aK7Y zosu7Utv2pFCk1YEQoY>vvA}Vk;7`A1+9i|-+%2Z(PSn|s^cJMMrLD5>T6?)o#+z+;BT5~zAUoy8 zYLBZ=Qm5>`w!`$5Q*N)mMc=ohgr28h*`+RE^-A5=TAuQiQxfr2pJbcjf;++;AoQ4c zF`T?AD)NFNt8zRUjn9c;ae?w?mBhT5yef-IB9RoMamrWZxSCKz`E5d1qlx&S*Lw~b zYIJ@fCX11VMy^sV zb8^z?OXO-aCec^aEwn$m7f+E`6i?!O_=Y$ejib7&s5%l0tE#xD%Cn0x^y{}oqRF7W z+5hh;Eh~uf$PV$^JrAO$ParSCZ9;5?x3p6r|kL^U%VT{i&gWlOt~o z09+c|Hxu5MdS>6z(9j$EF7J6`Ft~d|UND$yw&D_MFx9y+^NmZFet7xv?l&$6OHM^j zE-LXzc{!u|o@hLyYg}F3``Cv=!Ixi=Q0CTa#m2eD0y%>+8 z<6=0Th~Jo}fd&m@(%^3~Z#eX{H8gqlwdv5<=&gkErn(T0$U~SAi;?91B%53{DH_eeg&U=2 zP&8C&Nr<&XsYeBS@Tb!H$#B1Ob9Fm1{cE=N>~O)>nYXnUJ@uJ=kMNQ!@!=AGB|>9U zzU+8lvQxyfDa;SE@;*^#&T4|I%iPqT<7V{PzFgQfSnIwvsohM&oKSOv}$_>O|14G}2Hu z0hH|FXgpc6UjbBUgsIs|fZq5+k;FW$`w;a&p?Puuf9ekbFk!ZE-jD4^F znZ7kovpMz7_G`s$X3fw|7{ zf}bUd9^X=v(ts)~>yptPnU7{L)86M~4w99d!CaUTq#cqV+hvDI4%j67ZO4pLa!8P; zvI{cD1>Ws8(jKJUNPA6b57ItUn&fJoDeXhrZ%Wr89WbRy=GL>iY>>M4nFgsxs>i(% zveY3tB@T6xtZkBdnM`elymT4)BtKiEwo1Ak=~kmGl9g>nJuZ~773p@VPa;XV#Z-R> z(p#m@+Z?2~Q|^FfK$5W2cqVzbU2@1>5=ZZ{P;YOpb>_Y^6x)5L-)<=p_oW!BWMT9F|C1VQNYuE`xHCOnL)^1ZfjGw-Sk`MI|T_i~)p5 zZwMzx#EDsPT3G}`Ex#rThLeoXsiLIMgR5ci*Kqs><9ozYB3HvoI0BLz*pb8{bPH%W z;rT1kxkZRbCN(7>Bohgy?&z{@5e!$7W6}92eOxJTKpuo1)HB!<9HbVV3&&J6iRmz= zs)$=d9z$C~iA-66rll(|(P~-5@u2}luV9og3Xr9&QoSb%#DJK93Kzd2h8L5GeUXH$ zM4-gbfa=U2T}o6MMp9Ro17PvtD6~3V7Kq}jBqGi#Nuy| zK$S_U3$a0F1PJP6Nx9CEDk1iy{5|w#G|o^4Cxb4f6*H$Kpiw2i<@9dtd7&d%CA8wIrOoC@L!p$P*vMcx`2*CYz6Frs);qbT 
z&UMb|IKjY@<12`fS3W#b@QRvOTo)X!6AuGD>kiz1=j58joTJdxuQm0ryC}!Y)o;yi z&An3Cx_dQL*m}GW7}Wx!dH3kIYi%9b=L>DSw6XHG5Z{{g^;0IoRY@ zXenBh88gS5iTIq1&5}`rp$D<86&0;sO(t*vl0yljxg}vf8ZQYDj%?X7eJ$uvdV!B zeWwe}XSC)s1^>9_A7AGjj;#zV@oP=pIr~TT_jyg+fB&p@VEpr!v;*f0P0wjf&t;-T zPt(dvcegKXFV=6#yN%zbnd<`QZU3siY30DZlZEVJ0p?(wXn;Oo+SUAY|(e7$*l z@AuhY#=53BWf-+8`w2V?P}@k-)GCdn#Geu7K>|!a|9%FKDO68wOU&FBO8oi%kG7yp z4y^Gh;}RqJRwf=L-o;msAztGHUr*lN^L-8&(-oiE$T{Ibuwx8Oq0vrO;k$MXjCPSof?HJE!Sl%Ya!}GF5^|PZYM$i9%hnng31nZPCV5X(A!@L!&6%cBZ zS(U>^k>6~Km?Z0saiQ$W7BJ1vGJJD=pV~Ie4UMdP!Ib$C zQ8|ywj2(m0g=d{s-Si|9s1Pw|$T6adMMQXrY^=6aph8BL{)A8$0iY~6IoO3pK2!_t zUd`PL{T?)VSl6J{^+U(UGY}fPveya?L9HPOC7-gqoWDKazW@GAq5Wi`?vz${D(^f6 zvKQTT%k|lj?7o6~S8@9eZTsPMhfwGIngjTTz&a;5oCd|27KVA2yt+BXpF{ghUjG(f4+A7N%Yfz{z|Bh8ynE_;JtvCmcesI67fhlX~+>h@fjiQ1z?;=@9b`E z+fczh^t9gA$#L4zWA`P=Z(I6X;aE@ZuLQG14L_>zs!uocSD8dBl?wpn4!%Tfr!Oj6 zj;@Ux9n=&5KuG%lOda%RJF|Yx-T$;624Kik$S?7S_c0Bd;lHGh(;NDzF3w`dVf0Z? z{NpBl3}oxG^_qL&DSeDoW*o}+0|lXS8sCB_dvMcv@HjPyO0Men%z#$W?^)`<5LQ_V z29}b`e^T&uYrgKhy<3-pmDvD9;!ccfKEQ)@4LwfYSp~KQi?2#WE0%SYVN2{nb;gyp zS#hyx?!iXaBw^D+)y|34Txq{K|4r%zXW9Vwrpdp@r;e4mDF$Z_+e~9$_`;bnCR4=$ zl9X(h2n(AE(-tJ8KF5QD5QC(=m`D&r61lyUy9=P?#a4~Xm?&DI1z%%%?+OuXg_TYP zgH6oOC*pA5L}VDo6#6RF%h>Ick^UPM`zAn!`>L_!y{T-=2i=9%pw=4v^|Pys+R(|u zo>SVMQ-#Jct#K^#{F=XMWu)Nm()?ZPoXye5z!G2d_?Nr0jhd$;yQp;!Vei(x|AA+J z-m|}0->~GSf3zKY_MxveZ*Sc=Q!3m?K}&xBO%-U(D1(o1nW+l-X{%Yo70gZL;?XK4 zF`Tp!$8f@!rw^yT`QMa~gE~tG-7cXGwR~MLQPUNRgo!RHwWP9--uhS?rn)NPa#I?A zfnwBC02%IKpy|uN@Pokc{hfut5iM{ebG{-6-P!i6TXXMt8ds^ze3V!*b#0>Wn53U| zZJ9M(vs|=dy~(&&Vi^8Ewnag0xXkzxZz<`q^_l5gQwQq{R;VqtpKQ=%EFnWBGXohU zrM1drzc>&blm|uTdIy&d!a^#T79~?~Htc(Su$z`AN2-C*$&|B(9;vSbU{U!@T=Eys z<`%VGhYP(&wB92Rd`F;q)(w9){U4``ZpgMR`K^N=_GYCI5`}Gp+O|O?HTTfh0)LJl zhN=qqKRN!yFzKtq`Vuk=0%CeJu^#my*UD_GD$2(|AgxQkZt~_O1J;0Zlxd2V^Xl{qKLgJ@he<<}=0%WueUVkIy(Oy;+?WOEg_w$>Ed zR{M3DtcgPE`VO!hGDm7L*)1t^0`8|8}tB|DXdsozUE2wH$8 z@NFxra);Fr>xN;rFDL&Q$m-7lpvO0Gfga7@U+@oT{(;Q7qQ7xDL>Bz>ybfXCV-8@o z;OT(4fX&Z!o-eus%f2t&+a9>LA*{f?U2|{Gv)`JpbtU{^N5QvK^X-Jz4&1c>*2!~) zz=Rf<$h#-LqU5o^mp&i<+qu7+D{MVm2%OUb=ko4zMPFm)#f|HP)(HJQ5&^91BwaI4 zn26tk;2Zd}RGo@&Zix#Eur?#)76XSEZZ@eViZ4!AE9x%>XSGRS}v-$o73AwQ0joQR>%l7v)!;G0c znLu<9bJ`IDMTkk+GsxU_2nmWN5bB^idgqypjF+8QdT0@wi36~^4^XMWptxg{&0q?b zsEKTVgYroY)9v_s5j#YdexqvL^za=&Xi_P^sZ3@7k7| zU>)L=C_aIO$Ts+lGK)v^R^>bIappdTiT*^2%?9|b&^7p&kH->lM5rh5Kn`ZY;6WFa z=0}Fhw;a@eIUf2|kQPK zOG93pE0(b4$#@b9mAs+Qd_r0z+jFTd6#CI(IA**-%cVpF%z=idWS^p7nSOc{z)hGd z0O2)>L`*qARU9U8l)y0pCkUJ(aFzgV(v^QjfQ`}=J-tfcH3Cc#{UJU55rH`Zj8%~} zQXv&vp*5nA=|zbXAp0saf9hSM5ujbW`Z*>{rt)vy<`7^zArNR24y|(lxh<=GA7L+0 zPMe;dX&3xiVBwaFuv;zI&kKFoeh^9M%T2ReEpUE67`=Key?QOZRErt>UbQ^Ej2&5l zO5RZmoOQw&#NP(k=@Xo}zI6__Rg4pEwP2f_@~of^$y_tbU2Us<$#5CQvpNRKwSKN+ z@PfFaQai<3;B*RY_dyffKAU2P#$Eoc*juI#~F+iG*JL+jW_b#zf z5xS}GV8|(p#_P^(W}r32fS6cka4;nxr1d*Kow5b{CWFqBO+iR4w3x`%QbPz4rDrFn zCN4~dUOs#N>}#c_(A30>7hgCVnwma)(a75pdU5PxXkv2w>}wN~=gn_iubzGR^o6Oj zP=9Ws7bZ!iFsm>D53yt7Oi*<^zU3`Nr|*kESHanpYc4qZzSwoN;5_=-nS%4=nzIZ0 z8&`L6$4H6(k4>1+t@IY+5m6ZNuHeB+yjpb* zp;#u#1PcLIC74x;-NHM`WEYa`W#=u(@?W1smN??wD9Dz@{U`_tH9|uxu~+38sQdjv zO;TqlmQ$6Ija5R6i7N3TQ)22*@bXHZQSg_R~>lqdKS%2U6|;q?#FB zOmqe{Owk;Ii5N88N-ag&$gHG<->Q;@X(-B1fh-bjB6y(X+4g%gh@JsQXX&4thH*8^1RW(0(S(9 z7PUHpXUPC~?8~V^x(m>GwS# z9j1Ph7MkITO>OT7e;&-9d3W&MU}hieyi4(=n@e%cvpt*4y?B4D(0B0um3-e3t@~)+ zb2M)|%0{E)qy1h;O)M%Ad2_;lgG#J~f6|En1p%)V&kwT`HpEqny<{_D|1G-`sU`^b z7C*<);9AXF*X-ybyCi$kTYkpc!=vONn5^40*JQL}xK4&^QUtD5#J|<31PALI=S5>sZ_$F)x+=xP&WI9 zBA7)gl~Tczq$W&X7{4$u&%B+f;0Rn6P&*I_3A@gG;yS4n65?0}MP|W#OAn()boiU% 
z>xsqbMMPBpfb_vhbcb1cs^Q_LrAyJ_hK|C>a1re?6Um~f=m*_Mt9XJ!QVt;@^JiJ@ zpbF_T79^s?Fa=X>tR@_8BRtqCnaC6emAiPU{0U*Wf&*_1ru%}7{Y|DCUkt}6^4u_y zj)-E=uUi&K^Jf-LC4;Aui_A08n5^8Q{GSl0NC#b5NQCB?Qxb~{B>{CKI`Tb2V>ZzM z8yyTW&ck4gOKmJoR0*jvm};1uZ_u{AFxe)-Uopfrk&_B-0vI?>e0CcEbfmqUvth;e z;dH^#qd9s=!nPNEZFzfJ(Shg|K^R)Olxw;dS`{9&?#Z|AK{y0lX|4MT-a*Yfn70oW zef~^;vAHWV_*KXD4@Qbk&lh!_1!regEjW9=*m=0%Jp5U2!8uyhxXH8NdS-3-S*CAy zlD^$JnzwCPM<_-+?e9j4jk~^V-1ney-)g+jII1;{W}Yv);PSIOP=4Jz^S9<#p8v9? z??Fpn?xjM@F0Ex(!8xEg2lCE=V*AeAwO_{nI$r2Mr1c-l_aDYjYd?~SXX0N`oX+mm zONI6$g}S3!-O;@B=(@jL>F)xZH?TB&r*j#n%XC*^$;$?rnVv1egc!Jhzei*qu?iE3 z|9`ZO+9oYUlY;UG84-1ZN)(?%3@32Y6%Lc|RThw&>LQ+}o6|S5AI4~I=(v+Nn2FJCV7M^ZDa+=R7Wylu~ z$I}fx0DR5}l2sa?BtwkAbh~7Viz|oMD#Kfc$2qP)X*?5}Q?1^M+DV&;-to9r#xW@pR&;k%*#jWX;VdK$IvPiLnN&F&#$jQaJ+uL>9DxRtW{x)f;1N3>Gbm0i zE-X;IJn{N!=@9jCcA66MflU86dQ*8#z8IRABp^uMG4EOVEeRC3Zm=H@(^<2<;-#M5 zJ(ua2)1@AEGG?r2uc$wrrYF3lPcD)?#i0Vmxv5aS26FisCusIEmBdsmD~)}Q@2z4~ ztsbwM6)%zp&^XGYmx?6hFnoI$R9T|xMcq$mT4N)64$d}}Pe8MQt&zOtk-SJlt!q*G z;i4yqVmhf*uDe1bxG@q`<8`PK>`WaFYqNu1C4RRqpGoVSOx2I$uo#YCz(GWP5mSDm zu!O^U~2J#=1&AofC%WU^ElfB_~e7 zqokl)`759*e@*x;Y}i6XX8?VYsJdLX)2osTkY-FNQ0^G>K&d5uK_#7>2AD2A)VD)N zsuX3kfmNdXJD?|sQ|<(SlVyPO){|Hckffb>xE1D!wl!bi&cWq_5WInAGF8Ig0FaT2 zF5il)JAH2TW_!pP6 zmy6DAUpjX@aPG+U<oYaPHNdd-LqK?&e(ncdp&K_Fl`1^zOEM+aTgKXJ?-M zAi15}n6*#VyE1&ozwBSazjYTrMFgg|Va0i;tGd;<2vK5_ zOg_Jihj;jQZ0`uoAJlawR z&%T|c^MYUzu^5i^>2u6jw6t{~n~TnuObf^&w`kwSB>i(B^m}OF9eWO08?Rb>BKA@i z_v_`K+_a@_)d;dXcJS(YbHGFGh9CPSesfuZa!3PVSFP+Y*P%3FfUv8q)>$KfPs+Eln)avPuP+(0FoVwq<$ z8QrTzA+xEzjl8I;d6i50h?e{>fc`CvANIO;UbywbO5=|&+`f>xfHOkNudQ=#2i7_O zMZI~KyRt{}o?iS|T*6^J`YP`}QdW@1?wu<*J2YoU-r2$MY0dQ&-H3`?X}s5!Yb&@1 z5b~L4zlT_5{=5G7r0k)0WA|cuEOg%8ONC+0|2TF#mchUE#&R2U^^aGm!Ejj;lT@g$ z%HI+otwIkOJay`nJ`hALg}`4@Kgsy^?YHRZL)ty8`5W%MviwSBVBPK&(8ni&Eu7%X z4CH+s*_R78QL~A8o5=nv2-YO!Q+e8a8$8?%B#;Wz_Ccivv(cNlR!cAexF^m77=e9d7)CO6Q-b#kJ z)FDG48`7G|kg5_<>V1|F7 z$+8_qVXEYtPwBth;>x+~EALKLGW5yl)b}EEuA>yt7kB8&4c+)GB2)SI1ZWBdoot%u zvvvX*rTU80M!@5BSyg^bxPK2&UeTgxOypI1L*O^mwm-)M+D897NyCck&gk+eWOBjN%fRYoO z4GJZ5|0_gG-3d>us#HIhNIJ{)|HUjBK&rAyh~ydRMwT^aksEnD+k>$=m)H)JnSdti65f~GBe-3sCqaLY>UmV=ub zsakO8EgKJq^yYBs(AVtEH|FscAdj5cYxh$f#~G7 z3p{=XG~8-IU?0Dw=xbgX&ocr$*t;dfAKk#^ZywF{-S;C{RMytT(I=G@4+3pqQs zDJfu{Df79lj8>Gt7Sot z&1L%H?Kj=SeUDJ4y%g&>n9DT50LHol<@fsq`iu9)yKaiZ1D{b_1?I(GTjcT{tp_Nx z8)feOj8bT}KWiL}GJWwLl)3K_%Cxu4J`P5i)^X@X`TOIJb3%K6tVR@s!Kf_F$WlU@ zO2(4o(^Jvuv1lSG&yFT%Whoww$C9yVED=pgQ;BIg(Gv&^MqiZKkLY+JI-8h{O-x*m zj>=LjiR>3}iVNdY7wLO+b^=#>^s^CEM8&kVM^RC1Dsbu=Rh>Y!JzA*=JUA7-AVo)~ zr>;mc6*@kZoQ|HGiOC7+Ja0x<^yS#ai&9+s8JL!NE=eztntB?Sr)MOYbu2-3CND`* zBTrW}A)zAH_7UT9&&+i^{4<*9p>|)+e_EbSNmJ*fWX^XsNu^E8IoEUJiDb_Ed~9Y0 zoz3}Pn4!UoP2_y%q_59PQ=?MOeSUUk0$*R6!ku<8S89yIG|Qd9Mb}I$d5LcFF(2UL zJTg5)H%msRCugP;<4I|RUmw+gkG>u1j|W$9^0puepdbQKk6~xraZwz0#hsG-vPaD& zj5*>itXr?-i@R?M*7Z8<$5l_>)j-^9E5U~nf^q*%VYnzB01O3l#V=1!Cx@ov>eQqv z&tr(M#3p7Xf+QMm85w$V?vccJO3Jw?9o3pnZKHD{Ko821ye-U$bHWJD_z@RTi6b@j z6q}<&G{#DcPN6RO+=%rR|rx0svw;-^j@L)Od1aBvoY{!yfGh;oL+R z$-J;tBt$CbPu_fH9)AyTNf^cA)ekyngIIzS45DyZzjaP{PaKkIaM7y@d?hGzRIWtI zI+iN4RZqjX7k3{J7c(zx2E(@xr(G3tHGUSZm4FPq44bT>^Y01r4qPzRlM9X~Mu2A+ zl1SCs>NDLtNY8nZuxc0fvO04fSpqhiveZqzWj&^rzAYq8qj6b3$2jPq_5)mfU%vye z5>x;6X^BsMl&vzhfPqKi(zWOqP)G6-Fr2Ne9Owwf^B$2UV7u`vQWV&V;GTeVcOnYB z1>D*NbQj0v*#x1@*u>S?^#p^p$?2)-`P1Rj=V7*I$1S%YGQ zTk|p-9eouHY9NpGXKm#8s06e~5Z><0Ig(>Jw;Y?gC^0C=o3iqYQgXz29HUW^TN*~LQ(PUgNrBcoXrFqlwd>J)$pirBoi#l#P#L1%{wxJMTV7AX={#p z{b_1}kF_8h+>s9MV08|)hiK7F#U`bZkz8P8WO6z_J3;3~BO_l2Aye<5^!W5BF8fj3 z#B?$-G9ovju$+g}8@Y;p_-16HEy?SA_!@CdZlR==64q>5Fnpb{u28ZY|A}QJ^ZH-b 
z(U5jDY=$cr#^>9&TpnlT(!kwATLMn(Nr~Vp+H#0=1Ifou{JeRPU+}26O6XRJ=~l^> z$M0-hKJo6^EdeL?q*8R&(IfVRjToDPQM6V+Xi({R@b{2diOxkK?)Y_(9WYqoxXVm9 z;%>l-LMU5LZENY*&kL6+(*bY6w?@c~^_$Mdo@z216iX z;dn8KSV`OoL{O3oKXqw*B5tB$O|-&wDh7rGJcuMuO7U?pG@Sgr=gfsh_~Pbk15jkb zWVtZEI+C0oVFeB4ip;rX0+?JeTUC0VQ5TJ^dJyJf7_JW|cEUL40GZUW*_;!X{@PuP_R2)9*woAW zaF(F?rXk~4g78YJ+6Id{H}E1oKzIn;K%G$AI)74eRc5Q3>ANDkqn*C1)9jOvVyPkq zOwZxuQOIQy`LyFQJ;{!H;tv33@9Ut(xU&l?0b*2-mK>;7?V?Q5XWkGH#HSka z8emzg3FJwb2B9(SYGiGYqqy@$bAYj0K=LCnf%Y~(gbPec6B8sCN3UKQAH8Hnf+}JJ z5|bvRNpOY0f~pupqN;Q)HVKKbD|(gC3$2;4>4}NytAq~?oJvgG;pkWUx}x|Dyv8u# zRVjLPdKUQoQVjg3!L;$}xp?$d6<6Y^<|vW3K2)N$L&ef`iPxlI(ta&>eiOUC}G!(pAXriC%ejYGgv1x|qB)azR1{39@m|=%_XW zKrr~wSS7&LF)sNzDK-{0Xt_4RdB}!~)P&Kk17>i^?>XfVpe|RM*9eZsZBh zU@62RUqi7Eh&==XY7>g;m0-_GQt5p%6C6~6gK5`b{<dkv-nRqF`o1KU)Ir)c64kOpGdP0`1OFW|L~T##M!YWAj#G@-hFk;iL=eR=$4x< zdxU7)isPPNX&f-}vCKSaVT`-Sd6<~}!#d+m7c_Riff)1#{T9wqDwhQ6=~lhzR=s{J zO1Gk>ThT3Fx$}g`MFBe767td>hjZ_ike~2rpivpczkYbsvj@vKDJ%|HwNUwj0oZ>J z7;>NFhZtaI=>xc0NGR|Z5(?r08>1gYxuG~L7Q@By5LlJsTrnXB12?2<2oV~5esqFx z2MO*ySo@5(@N(`wAv2mvR6OTl4V8B|pQ0ph=@6Y{t91$|q$yv|8MNo#Wp=zT#5a4yJ3*EOjh(+`qh0TfY>)`=-*+o2l(nYWqHR zibVsT2uQ%4xChus4Yh|KfwPl|9W(KCgaud!SlhQ}=bop0l-$98;wMPV?A*?@W9Me^ z;WYb@`ZeTiSsr*7>Q|hsUie^~on*nxIdOsgFeMEt7xs7zS~W3f)g}h5`q3CP44uPi zPdL(Kou#?V7nezM%RulZPORuNsRI?9Rv;1c!*W&6F()MT9LTK$p?MyXfK7qMZ zD-NP#Y;1%WI{9nJnxX86kjx7o`KvZOzMGf+tbF_G;*}pd?)lQ;{!F-ELFyS$JOitz z@1Fv+2A&cDu1IOJo)Y=eMhW^s$_hFjRAXGMEHHBtallE$DY-7|%h);&!yf3(UC^6* zp(O_sZ0NImdipmR-XoQfUF5c_bc; z7r~(i8uiO%T1lp7UC@)ehRa#5@I!Kut4~F|RH`go4lDvia2$ysSJn2lskzFKs~Rpp z)F! z6Z-7Da1*lC(HMur&~P-q1L@8<9E66O;`M;_hFrm~d&ZX#QljC^Ua6OK0@QaneO( zd(Xw>Cm7vmEiyGO%q5Hi9AQKU=8{oT>SKsT5*QpVUFG~n_Mka;u{HCkzPL~qb(H$~ z;5d@cV>g*|!HPmpvLh~>6PQLX zubsxt|AaDCWBf1oaWqy$yO*O3GHJ-PMCB{=s0h~s_9Vk|IX63J%Z%X=4U4>m zys6f_)mgBnkC8X=1SX83l%mqxhk=Q`!JEId)N;2+sqJ3#cBj4F8=_0P`485 zPJ7rV)SdN|tb3}~Jh*q)pYb#)o~AVWWc{JrjY}u*4k>kgYyQ5pzmFF%oeA|Qp`Nsd zeL_8^0+j0ZjHg5KbfnqmqtB|)oZnQ;-C)MkqIg=;o|ccAc4d7kWt28V+R|jwlUBz@s%E1B=#tY)f#up$|op@ip8EFO)7kvsAH3{cBzGv}Ob0 z#ad+c)+4i*^rl+k)*j>ZDvHsxc&NE>fnhU}bKVv%3O8VUU&eX<0q?auc$_0=OEDJ1 z{G6H^1vHY0SjEjyw?SnVizdlckW3L1a4UqT;#icsnrOxgvzHsgzvSkoaSU8Nk!Rna zqD&>J1}S(W;SA12QVixPM$}brPIzQ$MUwMnLT!nr!-J;uI=RS2FUr%iGu+n`c9(Iu zSYDkX6O~aGY%)`1>g}GEvG$^EWS-F$i>ggT56O^Ei<1?~a-vYm$}Bmm)mUF-(HE#J zDSmQ%l=TX}07by z1tnmKRSU%>^QW^e-@2<|&DF3pyF8Y0bt|s!w5vNCES-Oz6Ut&86^y`xQ?`JVP9XmY z%C#DE%mm30##ACypVUd2AXtcn&A?+NEjJuf6DU`QS4$n#7Ivt0G$%vam50}XQuNkX zkud7p`>@vCa6a^|D_Lr^0KP5Y%S<(49*rFfHylY_Y=JmxkX966u>LpPC_xvZKwjz@ zY4P0f&Up)yC%eIH&OPTL9q#Q03o&%exvf3d*so$6$>C(iH6>kxX9&vS^o7|Jh`zv> z22(G5QQ?D_$OTQX&go70$e27mX%H6%QH%#`_i<4FCKbvEOevFj&1&RakX>SkBasV{ zTOk8~Vz}Yqil6cgJkC7kR3fDEvfskx)UN!sQJ^&yv_d{Yeny0RMS0X`$2%vN8!~l$ zpeLD7zY^-7kI%;oQJK1T&M)_5>id=Y{?)z8o|Bo-DJ67@a(`6Yu=I^gEod5uPt`_w z^-@EoyiF-@n}2g7v?E>Dz4F%mlgiQOGoc|RG&CRIC~s8CJ2U0oN_qGE^U&g;ZXyuy zf_2$)K@kcv8nNi6|V-fdrZV zy%n*wwfDeYOVgNh5f)O7w~QY#8*L$;TIV#8+WA@2p4d%O&fRF>B)lGMSqrxOmGdhd znO*yoUHdY@{Yr3u+O?mf!?505F5p=Pj~I8t&kuUQ@_i0`xMxmWrswhwGk!5tFBgTR z&ej84=;DHP4O^!HLc(tXd+0Ub9Np%Jft6I;x%cz6CoWtRAFD;NKbd+ye=QgteS{xA zs>rE}cfiQC3?n-r9y*tGEUt|+ecKlc82udvqaU>6+iQ4)xdr3`oRuP#i%e=mj(=%F zIc}{j#8Yz^)y#_Dq0DoXFg!^NA>m2x?(nPlmFmo2ECpl&UTP;T8wV2T%=5VPtfyk# zQ@7@+d*|i5udNLHP`-C9Q-4sYKbY|xQap#!?30f*t&g2t+5e+~4-RJPk1F*?GoE9L z=UAG3HZ7$W&ZZZ&5h%Ldy;Q#BS#~MW{!Cy%2@DVey_XnhcnKictf^OOx}lGD?_kNT z(1N&dnDO$RYo5+*poSjUK@T88rW{fF`CI>X9$ePGiRIX~Kw za8zxP&ZQF2Xt;PD=Qdny!P)*j4ffPx zS|v7)4f;25Vnx3edS$W>Sdk1!N0I_w9zA=csF6FX&ee5G=ho}H*Xp~|J%gG06H5Jwbmdd# 
z%>0}IuMCb)kYkiwKmrN=B7M2!|AQ{wMUt!1lm|L^GH6L~^*hbS8`?^he+|D<^$+da zS?UMz2k6IUL+8q%($JH3?Z7@7y))*vW)BQuIpWsGthbGi!$@(f5V~Vdj z>kED-^zG2%qS?-72!ZO_DK*xPg9#pT7d5=_=pTE8AsC&AB@)=Mna@*v6~n@mk=oJ$ zh#p%?elK#T1}C=9L#C)Ee~Q#fQyq`qz+q~QY28o*mXpfvqxjC098-`6k14@pY1c8< zmZA1$oBf$Na_3yq)a+!=dv+>$3MbrJ(@8@}i~#rkp;~h8lVmrSiKmoFf|SX=lyjXU zHJ0J>Ez=z~1Edg^f0+`(vP>5u|2aDQE+xc>bByvA>5TXyDC218C#B@2={O@O@^8|; z-=gGqC?TB(AJeA{Nv0rpKU@NK4~#o;`# zQoeHG-uOM}=8?HDkVZg8yrzE8vY_KZ7651W z4T(E1l79%c;l|x@4_t4&aBJ~k8!k5C!Z`*GL4=qHwen-w2k&{mp+85yVBF+xgc8V` zzKE3&LJ38AB^29BfJ2f?+iZ)nbpvjJvco956fuM>x!CjMSQLrZ$4l{Dh9}D{c`PnZ zXcajaRUlWTxi1C>j~wJ&gh)nGwc%Q19SO)3<+xj`jf}lcX@?Q7$n0>1zE)rxGh!A! zv}|i1?KPRp_P~Y7J>2jMmPan0?7%*YDwG%fjPmR)vWz$SZj@(^E8~h+qwJk_2#}hj zX4B}qSX`wRjA%1vq82l;i@{SXe(lJEr`_>7aI+1}%8_awG{tQp1&m>BjH*#QD8>kx z0^zHq$qU>d;?sV9+G(<~^q7n+U!^E4Jw!peG!y7+7KawBFwUP71tLk*9Hhj$7hBzZ z?E_VO3dS53N++sYR@H4BC#PZhOtM%x(K!4bN0SrRyI?hiEhZM{mnCchlw>WmiR#w~ zqbdc^;ktT=Yol;lh)pF~zoCxO_CpuSEh7xe2!c@PgTqx^X3H{--eG21FpsV&ii>J+ z)y9+da!kf%qRb*8#V@kZ7ufu43Z=uI7=otjcQJ7C-$0_-X3171|4k$o!z`oVsS?e! zt?kYr0b30z0rP8wMMB9|)Wc_LXdZtYLq<`UMH0tDeAWo{Sk+p9EHTw`9@b{+GyxPT z@@R~-b8Qnh1X9*^8-L3{D^;X5luD7Qk;IHJ3s_N!5j7=>BWdEUY&pF|2fC6t_0iDu68#5mZ5hcx2a=hffo8ey=zIzryf>w0h^$% z^D^QHjgRr!VKc>ym@I}$^f~zU8i3jZz`BZ}_-rw!9ov-Uyx8YLme5pK>!Nz$E$Se# za@dXuQ(xs`{LbaS(vu1HE5ZJ>tDkp|aDaA@je~XXK2EHV^=R`3&86d)@qbP*`}fBk z%T#LKVOFWyA;5A^uxbOWLu$SEO%Vh3w^7wls^&C7C&Nf%)3a0Y!`Li9E*x*&fZHR% zj{J9!1Fzz$wwwu~h;l z41j<0D$|QFd%f68vE2Jy@XzMPDgPr>GnA@)ejI`5*iQT8CDQ_M$^RJl<-bqKSxU@6 zuTP8>=&P(#VnBXl;knT>e3vHa8PviiX+0EO3q_YpGofZB)Xac?{u>)bmFcS9^{V}A zRr`OGxZnTVS3bD%dx;PC{_gc3U(ZyXRjSTrik?x5o=JP2*~Zv|IymmyhAQqJG>AGQ zBTU_X50o2r33mT1`IX(1EEHxvl$D47zdA zHqmp!xCo7=$v#yWky=@OEbKQ-W@lPgd586h8?HImBbGeCN+#b`&zxwce?~9n5T-9K z-m`fOE|X>>?|`Eb#=|fU97T@Xdv3!K0F=_1GEW-e6W|*76;4c8=Pc248g_(o(9m%N zc9gIqAD%$!u6~cIM|lxXbB4hUhmpxdE^LM7M9u^6K!oYh{9EVIlAQQvPP~$Vj18F7h62$xla$v(wapd;+lH=Co?a8I{=};5aLrM9a z;`Nf2wUU-hNvl%Qx+S>X&q7?M^p*>j=>BX|^Lo>swWd9prhQ7&zO<(?Ti(51zGtm` zPo{jIQoaw!&r6R`daHGzY2gB#^3RITiVL1Uo4kE}Y3EYe>b`XNP$qmx2_ISyKe-lu z^210bd{POY%=k{FeW$j(tOTD>6j43@3Nu1O=!F%W&5TfCyk`YID*!MCsZFzV6CE)Y zt8SUrRvwt6oF8mD;#el6B%fZ_o6_|}p4v8OaE^&onKmt^i@!lrLJBZUNw9qJ{GHd= zE85m7+A+&svp{pW$+nNoFjte_n0d;phuI949MJ3(`rXz*?PJf;!@wU}i#= z`nt0}N5uL)QQky%vHZ$LXhHEjYFy6*OdE_Zrz{q}obgZAyjK3ndXMnK15$#5k*yVq z8kdqQUrq<%inUJ(?n}G&add0uF$vkQ$!Nui9X`prj_JUC!{d^sBPO99%t$n59V0#? 
zoS{+0OF}`wGVpnZXM`SV7G4+NFPX??p3(x7G{Xf_8e4|aj17WR!W+pqH2>r;Xf`c# zM}E-B820S|7`XSO24EycFwwD#9Gzj6y+V1=Ayq^+d4o($^E-T45V>|7T786rj52Om zjdaFtl2}7(0mg*$e#!p^(*mH0x^>0uP=f~JnFRS3Nf7=SSDHxz^bilP7Exgm5~1} z4L1#jCgU+GV9jKu0@izfgz=XD9iHcl);z2pTxICK8hN%4L|_)m%v`mV1+nC^Z$Ln; z@9Q;RB*zz%kJvIzZ{A^Lcfg$=Vm!2yaovTvV!QC|-fn*z9EF5A_rskh;2w6E3x%nF zaYmYek&zKJHPiveMck`_&=-m6L=F(`oWkldJY&|T8axvxVFr7~N@c+jnMw6qNJgb=WMYq5~}2@GPPVlXBCzzDFf0q+zje zsWiCEXXzkLRF9=9gcB2%@jL&p%i!=+s%HGCD@sZukx9znX_oM4PP}G@W%<`}E!Ck> zA_@uy$~kh04Y$eGt$Q8)4?-R-cxYZommlJvm7R;Py#3m*zP8+-so14d?9y*AfQMkB zOX^Hb%jHys8dGeoFz3OB4Cv6|L8t3Qxu#_ux#cW9;)i;ER!$(YbAnb6F>{&|%+K-L zP`Fr&{x8(-UqJ?FgCjz?@;g)Co_gopa{0S8-+%q?>pwiVTKnF}y^&1uF{SueI&{3C zIOA}`1c^iyW+G+lk*2js6S^JQr9^hkk7qqa-E(?6Zv; zfTo;b0e$)><#DC*tX@85kWny zDqOk2T)(mX0KO@O>!TjA1NHTM&S(9`vz8}GgTs!?LG275Me8h6;*_vG%VePCTO7vd>_tk3DIvX_Oh`~uE3=%F zbaeuW<|?Z0YWN?xl#1kc%}w1Pkpn79*x{bdQu2|v6pk$I@4TMzb}3#sF8kb_8<8d@ z(xpWD=f~&AH~eK31OA;~%J_Q~e=j^y+@0C->N{tZ@^s9J{*X#DJ)$PmF9Z>2Hz%`2AgA4S{odQXu9ogO%&pRS>Im={!1J| z>*n>ptfMpS=tR)^rC_?ag$GV*Si*pYoDJEU(-bWUL6NKpkCx{DD|_!@!Qte7 z4g0buC0=I)ajFpQ2`4LOR$o=R;G=_cdlGTOqbT|AbwCcho?c~N_N3V1tfdB-6Q3Uy zT%{*h+ta&`D2+!c5I>5yCY26n*V3Wo^Y6Z@?7%9)FMCoXcI`-a9{m?y+RNTWd)d3{ zx7;D;ej4rlo8hPuZsWgwf-k&4I?SfKz3Fhf;_cY-Yx!7ap0pGy&>L|cBzLcan}`BX zL*8Epp%*RLNrkg}0l^oz`Ab_tEg#Fwlb(VFmQjI7)B@?JULgI@3e@tk%seSr;4jkU zhLARiI9dJ5{m~Czqp`qjSd#{Vt%MD>_)iXd5$EA_E#MO0Mxq0s1o=W*Av ziJ8!HV=uaSSGdgpWbKvp*EdD;9`QK~J79!-oe2&5tSf-&*L zDv^Qj?zW&MHl3(+!}VMy~D7IlmZ?e+6aBe;P^71Gh1}OpP&s z@*namvts}pn`t|@t59F^oW~$zy8zX>VvEHu0b!Zk%m}<>eZa`CK~5x@wNdsHFG0(j ziP4%M?Zs7@k~$iZ08XHYZV<42zpNuu)}@qn&G%%BBlGPW#pR1%di$%t`ql5gcIUNB zakEm~Oun|*f{Nsqn~6%U0|eGruI-dTI>uD7jtyj zmC2Wq11_RQ3Gu%Ic-RKJ`0Onf<4g{iid%Pbql+Ttg$ntq{FW=EydvRZAMDw3}5#%NI1_o)6lMd*^NIwI$A(Wi3 zb8$k1H}9}$KIB($WqKT;daw%vVSS|0H$j>3a);4ag1@?l8}=?qF`4N{ zbcRhtXc&3xi4I0BeNtYHW<;N9-!mYIQIGtwd?RRPU3|ST-A!IVMtON zm|)nGzwwfHJf^;vh-zoHB9z6-rVB*vaxM%?Jm;AlkH_(5NEb|t6JJmOc`aBW-b&?$ zG8MIWN7QQEBUI0O>I;Y>v?9tVDMw-x`pI&fsd~tm9SGCc7^YwKf&8Hi&xnDjFl#p)|hor~>VH zq{_(pkyjawN~+vi86U&fs5-(Pk69hPXf>}o8h?A=Xr^M8zTu^H1(gvrlw6!=TA%0X zCrzLF!V9WUpJlsXUrKi=Ig6?ie;-fH>wjAguhTVIdP5sFeTU%!(UFfI0WiP_q9Zqi0Fd$=8EdYOH03`R6@!{5M(1U<*nB9RzAE) zy&?{O)U#g@8GBS@Om9ow4jHp|ral5P4zMyj1(9)3dyDGGQxuG82#A8Paq4?OU6v~< zE6uzhSoAh>Csh#@bMj$0k7Md2#02uVU~$x-c&E^ECVvD(r!ZrFLQG$lfRPEwK~^BV z&zue7(pBmD2#YKPazNQ+EAse)QI||&l3a_V5LT+Pj)#txCTX?opeL&d`lCobC+?B! z>FI))r_5FW?C3=GE3OsL6(+&J7(&vHQ`^xCVMaX~*w z-f`mp!6p(Ln{J7xZez}2Gsgx2Q%Q)v?m`jjVar`^X5wM7p&Hp?Wy4XTZrMT9aIvsL z$E=W<`&`NvSYlFL&#USOj+?;6>5P!(@6*P_d@>h9mb{8~a+3z5W)b>Hbcq>q!6<>h zgH^c=BaxudP+i~4M`C!3q$Y4`&MKYAF;757rVZ@n2fH7HWy0O1xS>rIomJUjW!hEA zua9Kh-HIFHqv-6~2-Yk?EHHc+H++$F+2DHFnYFSr8=k50_qyeO+ zKRAp;53B<>m0ONwT**HQyQwg47%r6_u9#HydpU7jrj@Je4|u>FZs6U`*ddaHnPLks z{l#WcBPxULsd9&qPBY5*W4Nl+4)V`R|5E?29sG+2mj}OpvrGq|I7 zvEk9XXso6=k=eByHM_P${--FD`NL>k_E7-rL=xM%x(Krpy@*C1r{{lxdj13+N6&!? 
zD~}mTwzi2P0@bWn^{rL)tyX8M4lDF|;=}56CFwHFnc1jV288q|Y|c;M#2!k?LK!XU zE-NBft(`N6Qi5X=qc=oq-ct?y%z!&!jDl(f|8-@z;`vmJTE>+mBKRwz0!^XDexdv#hSlG@bMO~EHU2x374+LqcbRkGfrHgL+1P??2hw};6xFhQoC)O%XP!KTR zw|y#*Hx1wf5XM_Wo$}{?6YX`(TXrrXHt?j-$%p6F3V}JP`L2TOJK&vOD`3 z0HM-SsQ{59{KOYiK@)DI~218e?)wD#HXS1w*(?p(Q^@gGwB zhtiG$udpVJN|^q^H6&J9#3nyrg85+%yJux3AY~Z29C7Ebn2g012?RA-7zbBsC!Dnz zomUm9%sxomfTjfLF>HfFa9LtXE4Davy%w~FjAR1{HIbmU0IeS3HOMEBZ-tqDw-{qSYv|j5g#S);xADncVtPdzn{%zNMUS64&%=CUH~T2@7E)@tZLb ziabI?1(#3_|9(~qCnZ( z>oSiDwoG>ybtmVOj$*Q3n3y~Q^?mU8v+ zv0Nn^4?EyXRd5`vBR=P$d9YrCCP1HGK$~7q;d!{>gOu;fn~l%g4mKO@?$p0jG51W2 z-V(!uxF8M&?5W&h%;G_+H(+>}k*AZ>-0A3?Bt)IX=JLC`h|-h>T~36&?e z+GiMz&v`C_>f1EkQ4}h z>~)qFkuXs-e-Xk2+W6{HaQSQLs)5zPOlZFn+K<<<6jv-J7e>;ax{pHP+pngpx>jQM z#uP+x4xLa!Cq8^O9eN?{DZmXAlF^w%J2z}4Ryy0I6FDFzfCWF`B5JZ+7^JuGM)V=D z#IF);y@;oCLDTv~MSYb%j5+sP~5qIxWNh7#p}B50E|=gyg$TTN_QVX%Rs^7&<0`(he|H z0Ikq4p*22`-FF$HWThrbiVP5dz?kgFuviuaTAd!x07z7&HE)#`WyEB((51x+pqRt~ zl);qNBw#t!L=_gHz4!Wl0uX^mQDRXuqFz;Vl=mH1HvfSHN3E0Ol~Nbg#tHxuc{4l>X%k^zx`e6h+^8SuQt1%<|# zAxtu!fCeGe4nwFf9B2p5s0ej3=R~j*K>Pay@$6sUcbnb5G$Uk@IO~VlJCdl}zsg z%9XNraGCQ%O+X||9;4yCNXaE6)`gy{&UOr7jVm-{V*CVzN z+n+mCQfrMULsJzsQz@ccJu(H<;Zj6^OAj=22Ft=F0VYw943P`*LgDpH?zcJ4vZ&WM zjV#ErkOmYGMV|YN-^hjH@D!6|H5yT_ggfe+ECMu5L}ob~KS>W#vYTLvwmtxB6bXUw z?K;{7*R~OC&iX>P1B$OETTy#wSgGjB7MI<=e0w@uTCJ3}Zs9!`ZBXVEwb9Ek5cu2W zZUan~9om4qeMQEzTk-5(A^(c90We5q1M}URDCTR+JMWBu+u!i-%zDbQWz9-iUlwn< zedhKH*>I&2Zb7kuR?-o*l8y*rz+BE&PR2}gDo|JG_k$BiY%J<$a#Q)f>fopT97Y0@ zjOnz^G5lTlUKhqr4SxO@NRLOtY;vv{%HY(e$4b&E9Yv>Mc6vm+S8*FBx1SLN%f>EQ zJO950yTkyMZCkYh@W02@%aj}gBrsb&_L;N_<-7Cvp%1+o{~5)9Cha)$5RQm+|K!ZW z97~`Ix2TeB$mp@TqW{bgiDi`k3vjBz@M6J2*-xgG0NKo4JK zFi!*)n&Mq&o){+7C@NRH%O)0zx9?#r)CGbi@=0L|ca4`}j8q;EwR#}qKce_ya0UGmCN}4I@q=ZRaL@Sw6jkC5)nZWc2 zB#OxbB_btE%3{1NBiY2qF%rqp44*vmhu||@iL+0AM<=J_vlG&Bnc`04BP9MM^e0dc zh~ieAPXrH6NjWHBT8QvZRtV0szpU`(^y7Xuh0e6~lNB1$)=yUGOu0NIrwHPopDe!eF|=K46Z8F-9fw71(8^D!bX@(!xRnx} z;#1=Cg?F)upUzgX3xU21q(wq;+2Zl#Csv2jWl(E9H6PgW3_3(?bts%J%6H!SnBDut zoa2luVb6IVI>M|Lf4*rntda-fr#J{!dD?bSv>tQ zUH!zkb-0Rh<|j1RxcZ54Yp+MFUA|0})ULkHzU)b)PCUG{_wJ!B0VgX}>}wUip%i9rSH|}`FDXd;1G{xy+sSJEe)kZ9g4Se%Sl)IjjVG&W9hy#O50g>$(~g2Lfy8)*`>aHs~0{P|1k3VHR=7&D1FavIqCA2ke9v) z_&Y7GOcK1E7C&_Sewd#XNFzeW3u3zcu+lyRv8%%aUbn-u<*gLy&5)F4Arw35WE1a) zy}spUcc4~oR6?z4zJLI=_sCYzc!)o$ruz$4;VvWOf`nCo@6AvH70Ir{rzNy-9!g=y z-EcazTXDB-`L+D)A)f3NI(xItyR)qwTV9Vy!eU-pxn1R8&CA5rY_NLEiEpFwR&uDJ z?w~@Rhh^l+K*7?y<*aT>S!d`(9W8c8A1(Gv8!h%j$V&@X;Vz}oLc%J*x7J{~UTCzm M{OloH!tnq90^Rqj%m4rY diff --git a/struct2tensor/expression_impl/__pycache__/__init__.cpython-311.pyc b/struct2tensor/expression_impl/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index e13e04711aa005f11fedf257b3c28965899351d0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1314 zcmZvb%Wl*#6o#F-WhQfa z6$;(5VmDHEm5?|#MNOs?{qyTPImh;yljK{!?;*GjSAS*?T!em^;Qv*-g!3N&Ur~%O zik-N|owSB)j^R3R9oH>4fE&1Bxe46FP0KCd7H(N~fnDrcZUeV*+j0lEgFBWzU=Mqi zyTD!CwcG>l;hyC_a3A+A4}b@FV0j2U#6!y?;1M2K9s`f@*zyE;f+xUR@6q;+)1P~3 zCZzU5&iz!xOHP%~7Ct@BBvp!u1z~B%r=IsAq2&UdFy+q}r@l(WlE?l5^}{HFH8l1) zn^RxIf|=l49I?fxUC-O!-}hiIbC}X{8R;EO7fMJ!RKD8L-b{*gEpu7E){N&t`E>2( zG|V!7LR6H{G_+^^HqUR{IqQ`(5O)rw2;(SJ8iH<2Go295Qf9(2{3A^zx!N3VHrXPk z$EG+9GXirlT|;=Q6cQ~;8P8*x1C_7T9-m4vsE928l>%kj3dz%J3w3 zOVwP9EI1N!t}+-#Q0Z+gx_S=wH4F%03#JK~W+xZzm;c_^SMn|#)QcZZo7 znkGZJjG@3pAfRlKz(ivtd}v(-_Ct^-KjaVSOMzWt04@d)Dxko<38qOZi$n5 zS;+HCJiY^Yp%7RK6vQR55L^oK)F`pAll%};7?=4>6mkr5HPsF8GxfwKV9!uis(v|7N1grRk(hY4RV{B;YyiY!73S%{=OWB-~mnAim+0b+lNohmMD{0Jo 
zZTfj+SrHedAuu$BOBK;wUej&#e@7rrxZQ#_GTgf|I(Pt=`NTtL(kD+ZS~yvzp7x zSctT|#n?%g)>T--GK@r1X!*JKFa;AkYbctOJ|TRgSQC<3RHO{dYC~7BscJzr6zr3x zrsWKDI3X!UETe!~+SXm2jIb7|sBG35q%Ljd@_B45?UkZwiky_*$Vyx4hP0{W3`I)E zGB60WP&a0Rb&~`(Nk)|wtQ?yN!&J0Ft_bDS&0240YBokaP~S%0z-B#Er4_|CTSt$} z>CB@EtPfJ5CuO%@A`Lwx(~(s$Cp*pN@}w{_dS*^`t8Y&*CRW%=(S}I3q+~KlZ_YIA z-KnoCn&R|AJ2eVtY?5t?D~ZPDOtkS#vTr`)$K(gpDk0Ehl4hN$LBVx~638{jM7T z8@J{4i2@b;FS*#oM8JyP)6{LH_&_nN$lL4W@X~oJ@<93VhEmKZ3EpalYBL*oEIy^O zejSam+8lTjX%2jcp(XObL}1I7T3?>AL?X6Ac8e9#6rA#6MsU15eZjwT0sMXSdlVmY z2EwKLtaBH|C$J-Y z3QnJS+<5M3rxCBC@Ntm2Pr>aokK+UGxBOxvq!A9VLJwGcvLbiX^)0BjV1*sGW_6^B z6jSoKZ6)QlwWnDI+J7*2IAD&<{Z1}NUooSvl!RBl!VH|&fx>4Vgrk^w%wUT?=61M4G@N|o zhK<+G?r=5pS@X7+nSfw~YaDh27|jV2ICUPuZ(!Pg?eJ?JY617D$AO)|8W}ZvAcczr zxkpS2Z3u1Nh{HNb+%aU9{#keqC@=7UdpeETA}$@9P23Ux;%wfCQ)UytWj61yu-Sq; z!M`}0H{z7pf~U=9bh|6#I9WWbJA}3NdVMzTo1gE8K-oZ$!So)y!CUD^ilhLqNY3d) z9LZ-VGini_uZZ{@i%ArsS=_Lp2*N31H#tIrI4gN|8DYjod5EMql#JNdF6C8*43l8C zkk?g^gNUFRAOeL323%YbbT;WCfW|yyMZkz~8p&YKBVi~!CX;YN0zX8UO-0&707R&o zpn68#0O$i2Qwr+_riIf#w^8(*%{j)+S@Q&;BMqw~9Oo0<7G*(O(dWTKawVL1E5x>- z%DLqL$5#k4M8k(8+o5=}dG1D6|aUlT3~Y40ue zgLHmFd0o>~gee>N!=IBIVs2RLq1rGN5)~sTbUM+rdP&mGP$Nmeir1-k8TPizRfoAy z^Kr2IDA>KPmV;wvaIDG+f%BE-F0=WZ**p&S7(CC4J^p#9_n`gYUO6;shDJ;5QwcUX zqgTtpvu5yYm1_x%RQk@EeUnu#5Nu_|UUMbdV@3xKHb42`C^}h+PF5nVW@P9%GIkUh zJA7P@Ts0$COOdOUNXuSmFH{Y3v1B#Gof$HF;GjJ=u}ZY_>Dhz(pDZ3l$4k+1#-;x_ zGI|slJ$$nqnKC0&rN~qrm&@P+VR|puacQbVyG{7^l~0OC(TP%Yg0UGqj$Am3TsX{^ zBeP~?wiKDIV>9`u5Et(*bX@D&G-E@H_k|K!Wpo695VRkhq_C8Y2&*^!$50x!f2#TH(#T zNKuOuazGCJOt^&^;lIS~e|bkKsP6$kl7FBpBB6Ni3IvCd_KxW#2ZAc^=X3-2NE>e1 z*EVul%;UV|nCFV?g!pa8N3QT4hSMBxkij()KGN%l*`LL_*b!Vzgl+H>zGshi&f6V~ z(d?5t=&%h4e-I<$Mf zB7{Dg{BUwVTo!swp|>RTQp?nbQ~P6Oq0bchNH82#wo58vJYUO6~w21f~zj8!5{ zW~ASYoIAWi|4NYym1y(sWoA!Si>l>Ta)7}preGfup-9c_>i`uX5&jV0C!a=p(0XLy z38-I=M*P6G$gcq0`p^mBrw_0;;G9M{nh0#)UU1SS&VG{f_UB6^XRjn&_bGhfqKcrw zw0|ziw)MopA{X}+dBu;p6@ajJ#m6{0|13NW0s(W{1UeFeB_yYrzpUTZZ(=v|Fvg5p zSiik}E@h)jN3!Hi7BKbOsJJvcV5Of&0T;^g9i6q}fZr!4P~XHhEfLiGCZ{Ip`;kQ@`tqJ^ zXphVCa_~vhQcMnEd>y~q_^cp@k;RM45l}SMGL1C~WFSu&i)>`eTB|3CZ+tc)$HBJ^ zGq+<7%1m}3e^@62>k(NlvQutVy4E~|h@R8dB+6~z)bkShI?grOogC#EaO|z0p^id8 z?owRp+vDEnm-^w8JDG)TUtOrrLc-a4$v}c|WKl>>urwlb(l(AsS`pXBSS`%SxDHRZ zTAVANW2hO?)f>LOd(cz7{F3*_H@_D-tdEKw3wR*!7KFMN%Wws*yTN?U`}43A&33qy z7o~7p;!-3FdsvP(u>Ie62vuUqEDv_#TFasi5U~d@CCOERjeOm=x^$QgFC{z9c^b0G z9an{j;nSF1Cplkbokq8r2N473-KfwMUvX3b%}R@ZyD!6Ol@ zq;x}BPuXY@5zq?i8|#W@MN+9;F=wPwRy?PtU=O4f*%FN{^!^obh^E}NTCH}7@z$0I zS|VLav>6a;vsAE2qbbD2w63R8I_&~{mb7cb&g*#H?mww+SBjsZU;jH4e@67|=Yp}H z-hXocSnN6yyDE`3Gct(sX>;Ywh04$ca$Qj_*!?9J42Hg@qRNUUu4j1n?vusJKw|f< zDfX7whdC;Zesb^y3bH#EA?-kcyBZqoX9ZYBBR`PfwO8=$lYV8gccTrfa?pDJRC^748pR?Ymmt5N( zIJrcx^Vi^|(=`KeD8xm(_g5;tV`Z^>H&u!Em&ATMXppe(;|WgZ6hA!?j1(3hGFi1( z(N?eZGwwvl{-t2Gc1qRQ+D@DP06H7uz(0jaf&kkcOljXxt#X$fzYMNi^gy)A{YuZ6dLW+!fcLhuYL z_QI8D+ta~={-fwR+ykv-CW}%eCx3ZVh{c&|d=zFzeTFIf-Ir&I1$ofI-8^0Cv@b2`)W- zH>Nhpnf|mJm#OcV2z`3Tpp1_5G}?%}_fadLD~4?V?T?D;W)VliJ;ypbZ=f(V`y%^& zaI*qwS+=71Z3n-IU@UGGWFqvF1vzhx;;2T)(rQC6Rx>#Y?Y3PtICkytLG?2Cb3%?k zfKY!Q#V+@}xoyAe7sJQxSB~1Rl-p;`_StgtoY_3Ld#}1e9(*!mc&73ONjvjF67_L>ife6}EqG(zW~aCN8ULULy~${S6g#?XkoM z09hJ4v+qzNxf`ZitJ#4>86}@jrGO}Iql*Q#S8ZfOE}dcY52@o z#{Ml)4Nz_R`WOA?jYV_%ZT4Q}&{$V=g#$x}eb44TzhRDCt#Ww#4h1tpwp;k^FiQV< zgW=sok48T_;ao0ftPmhH(pa_)Fasl_bF<0LFI~*Cn0Yb-vMq10_Ma0{?tctPGAEyW zZ`!un6j@6f6}?n2BW5YxOY}O8!U`|EcIPcf#>BLN6x|P*wS5!0X&;%}+9R})ztw+< z5-9?nuMUpzVwI!f9)87%@5bm~g_|y&{&~S&F8Mzd?tIDrsc^$3|EI!@l>DC;+-%AJ zsfI%Q&HZij%tVz#tVOcj`i2E!j>+lx&N-Th_&K9N*)M>?qkd*xjgZD4H{}sijEmnUNK1 
zDO9`bVuNiS{E#Ak$b%j>t_`QHAGSpbHmC)L@=H=Q-|EPU2=b ziI+TCe!(;2S@6zy7lavM!8haMG2NT>F9c=+lox;x&IBp%13okpqP!pY@JyKU0pKGu z5y}U%(S^F1I?9K#^)n3~?hGe|Z*WpX4$e241-k9!xO?!=u9+B>ih@)fNHv-1RZ`7V zsve{owvcL(8l6w2m>iyOH8o02aMxztHM{OwWS=vqGvA!2!_36Zg{9UT!c6r>l$1^>Y9@C@ zOo@vr^{SZ30jCmK9twxwyP8qNMUuakkz_?&*gWT4M$SrNDwoe)U&t>h+3RAjdNrkr zkbOy!B~Tg$H4AxFPNs5F(t2-VK!jo}XW+f5oG5>|n9Za!DrCQsxhChtTxtP}adByJ zF|Wwt_56}}HFZrkWQj=RMItM54$FTi9G;z>4flr&;r>BIR+LOW*C&c%ucD@i0ws_# z>69vC9vQx+L>L~5Vu5}256bXDwXZ7iS}NN&n8~RVW8tu&^N{#65Teq+u2%7o;gGn0 z1X3BdKfL-aGLDXbIMxDMUd|acvKnj1;(DQ_$N=rHp?9FudjBkLlrBCHhZo2H92z!G z%NPeHcP*Y^YA?&_{DQ2w+G_Ayw3gLgqUdfhcwjdsyXhYs9L!xan~yfOEBg!Emi9&N zn$e!32<-m-kmzjltz@>{7x<3R5vN)^g(|W85`mF%ewEqrNykA;PqHKl}#^YVZbHxR7`=yFm@jsp>S7k-z8I$1StYeMEDMX zR8GkgU}K6(meT5|N)terGM%=p5ik+s=GlJ(^DD6LGB?A6jrU0Wj92o=!o1HEm%Prq zIpHI2j+cZR(HXzwlLO9Y4twX3{9s3dAx>(Pf*=)|rz;KWw56obX8EPC9GN#)AY2FP zKtsJ*B8LW_Q!`kv=ndaY1C*g(sbN5cp|k}L%KgtGFLJ=J)mSRXTgf3!9rS>036;wy94B}`A<0&mQM4!4a`0B zhFf<$u@bogKkqT`fVXtthSYo36CUNPXKM{=$jrmRQhXaU-?Oy;f-mp|E@^dU!2@@G zQ(oer`kv{P&XZtG^Et4imdOxPtjdXi?o*Z)Wuga@$xJSzCX;$3qa;nUs0-@#MOhF3 zFqK`B-y|eY^gz;Rq+rr$rMrSIC~|g=G{Bn)uIPFwnM`L>ijqt!*laM}*S>jw_$pYN z;YDd-iICLw;Zw5m6E(j$yqqULQ5I8adDu0otL_}rBD1iV9k!}6w0NB~!#h@X*XrgW zvsx(-QO*PT!y5PRLg?pbZk%}}v^*4AZvIef-FK($i=8E5TocBN!gx7QcjITpKDLSS_$BO>3a&yPpsT&u{Ej!`stu^?22I@H7Ph@q84Yx@Wkj+IsNQx(R+CD8iKOME-3-b~f6WTl z7S4d8XE7jaTshkK@r7GGpA8lxy+xsy?1t2781nQ5iXMf<9ahj3S}{cr7Fe4k*5{?016Bhv$K5On8U4XM`j zcD6<#Pzb@`u*&52bzv%dr&o5exhLlGDoo)Ea%yfmFUjxGW#F5=#(YF)V@XDDi;11Q zn#oGGTq3H6GfE~0o)9cf1mCCEyHyc%i1mi5j08=#9)ndalbp|Iaxfaei;`i;ET&*- zQt@8?Q1_^Fy8jBSSzuNCv?vN`ft-mj9ob|L@MJHNeMowN=uuX$Bo)`|jEWeyyR5k` zOG!0vnjPbH4fL&;&fcYmbMmsmE7%E!uK{0dI)$u7c;%IvGst?Lh1|+N0$JlKQEtyb zN$6U;bn}hJk=C2p&(CN*hf0yxwaDugj`vMaV=-2Wj%(5JqJJD^7C!%xw&!pu@`e_9 z17(i$ROo0adQ6KREBcRp-N=PHwP639x3q!N<$;6Rz?)CKeCQ1SoI?t{&tLR)tsni` z&o#At(sApswsYuyLn(Gliyd2=dK_ueB3;G!kq6;&e6JQCg$&U#$QvEQ*NxGc?J`Sf zROoEQ!46R$59Y@+eflm(;#gIO&1p9q0jxk076h35;8&~HR|d9hvMT8tp!FF}Ca^k@ z#*fBniF5&tSlY4z@EGSWE!O}r?uE6$J7>CkCSBx+2pJL{(hpw*JqR$RXg~+ia!Xm;DPGwj2zOX7(?pLS^SC};pdsX_~u8~q?zZThF6!y~!u|l-q5{e-TETQ3K zav?7*Ws#32llV5{Ng$b&^66xf1VN!5w*3*hh7mNUq#FsHmSKC=5^P32z`RR^V2kO2 zcbG>=4&t3~IiJsx!HR(Wf=^SA3Z0=UR8umuf8sDgjIha*bVDLvJO7H({=+^VYu4hP7B9 zQxfJTk7yHTZH-hdBO|;wwyt6UV%5aYdq>x&w4T@S!RVI~{n`d$osOd|4gc0<8t!NB zKudnFPIvSFRi_(ZispH=_4ERFyw2&K+m8L5AjQB*Zv}}qXAXyiM`{9zcIw9OfW@yx zZ%e)_n^sA-Cr+vbjG28)kqlx9GAP5bfQZV{WyCY~0hB>kErtci3Lp!jNFfefIAvJ^ zpaKHq6d)GjGQ{6v%FK!@t3%^}2umI0Pyj|?UbO<6()lI8KrYLo1hF5*8l)jmFK0;A zd+EDGHIK4OixPqnv?!+m3Xwldfv+nM7yt;q2W=+{7AY)qh=#oO634jYsSGrMYPU)+ z$>&uAvTQf6 z*lWHHaF;y5H!Ls%Km!^hhoa!`{z)}L?dHZ*Z1OGJ2Zii4i+X2{E=ISL@y#%;oZUz#63#Oll~Z5_F=TL-6?4|^W(o>I|IQlyJ3Azv%Rv{zOankt zhWg;xpJiU+teBq@XSY|DS#jBzh@fM3%hsI*AW>rPG0T^qqq%C@o$6Xwz3GGrIc(8Q zfb5ClaaARm%S$TaaM=E=j)%}z7`$`ND_2dslG->{pzwgQS}-V@A-u5ZnATL*v|(WW z80I0KQ=y4-b}r~8T4cyV0}RpjL58S8);O$yjNi0?n>7OHHp*7J7pp5mtFX^ubN3lW zbe~zrs!#<>O5q6I&x+_ma;%{B*RW1~mq-pEH|Utr>kHl;z#*=)HC6+zIAD7%5XNM1 z8A!r{vX}&5ej%kM0tEL!n3ISLEWr&7@-7mD4gi&5^-KuN|D{mO+xVDzeOM|O32V&p zLuYp7)!I#O0VaPBLdrQH)qqJkSg!>;fqon*@9x%iAFKp8U+Ys&@cEx3sZbK;S~@9Q z!cdd73u_mc--L*WD=0xP-3$NEUW}EP?mk6JE<&E_xP|)8D}$Rt1NkjDelLfCknj{&@ZE+I-GhI2{es^D`M`#wmpd|2P>+9u9t)oU z+yN&6y#VjPaFFpY#yg;n=2eC3U`RtRcUj1|iUZEL9Jq+j#yVbrkyx4yjKso@-Bau$ z=R#Jjwt-xnjW@%);EBE0XhxBBEQ+LoSN*%-FWH8V3_|G{{LyN|AF14R3j(QkqYz{m zGhj4JMv&}BGK%B{NC-IqLKdJx5e+hq`6iG|0;z>*&;W2n++^5WPV-lo{Xc?#RwX~0H~ejVnbZ4C5yy%agDMGhB* zDfoPuEeTT^u#v-GM=6*wd_SsnZLzUfz??_Ke^E3zRsEF9nJ+XL&k}T_y*|Ff};CHp)xkjeo`LI6*CVW3&`G 
zphXT8g#)zkdYF0&6x*m>2gDEDAose%O29+oDpYM>ta=KbHQ1Aad23lCj>An1w*|p% zjFb#2&&aAmeStJWnhv*vMoi#KWwRoE1h@z~iQpsE^xZ_mjY?$|jJv2WQd?3Pu!BI@ zX$JkYF>nW&7Q~1q9YGvbW(S}U&9SJG6&W1bOb(FI3=I}x6P#*c-Wm90RLHChxnj74 zB6}(AfkQL{qDfdmbnTpZ-lba+RyMYaZpMolY!Frq`NNE&Lf!J9fGlSq#GTIPU~?Va z#p#5X`4k~sFR47ifd-OhCce1bG?4+K{wzCULU{khYrLFjOV?Z=ggO z0|NGG7uVML$)!iF{SRCF?@X0iC$-kewdr!8wHVl4u8V#2-s4Eu`q2HJQe;w#OcsU7 zjkDZxEKuck#;1F|G}Z-Ub4HK-B$r>#CGEMOY%mmQc+aL2Z$tLyApH!1b_7QoK0f>? zl6V+Nlp_6Fq`xTiGs7Xm^Sd471bmT`NKPSn6A3Pg%v9wH_0sCV{a0y4LOhG0uU!~NX-tCzrqK92@?~1MDwrW>GFaP zh~zXM^mZEPA#mmPF0Fl_;sv(CZAMynZ?79Y>_te&#+l%%csS7`($!-9VCRqp~ zy>mf`Mv8$U(mSe`X*cMREpfy#3r8GdI3gnO2ya@%AM+jrl>Rq~iWj+&gI{)PuT5(s zZ`1n<2Z;`{m-h~=cib8K;x(;n9JLH=gRl{U{eoR*UFgu@_5ZjeYX|N-#w9y8g=AL# z&Uwbbf!yI~D`VxL#-5@jj^}m=Pc1xxcc!&B7aa3Il_Mxd^9|AbGD$50)&ffvED#x3 zFqsKQ1XI}zHkk$Z(}H4VetQ8*6$z; zwoeVPY5LM(FzJGlZCHZ3A#r77^J>)&f#C!vM|w~nJq)n@h<5-ozFfrn61)DmzSGO? zEmmpVxa+^2DAy>Yk?st!uq7zzlqUfL4-ln+Uh$HK?k*O`?NIw^C)*|t8u%{f1 zYQY`l_Pv#WH`@D@1M(b6h4cE@0hgBz|MtJQWvU`T5$DwR9syp(jy(osV+80p9&{4J zC3~x2$m~pOFT5K&)fEJemm(8dWTGfcFnfhwzbpE)KvV=FtM=?2N1WEJSqfs-RUHX# zjQEHM!Vq#$Uz6f}7-hm%zZHYHYCH=MR}!0dQdOjZo83wr^1xv*h$DJ_Gw>*W@L?RB zvQqq%7C-esDn-s~k@H32JgvTK_+X82{y%%ExPn{;6&vg~dT!2xUDZ9eA&Th3VJ4v} z3N(nw$w;c5b~;j!`~m)zHOS_2+68plkwk;mjz4hbjz6&Iv)3lzvu|(5zX|<#ATEy` z6WtCwrda+u-niU#mSA}6c!L@pq<2eCebVyOCz+=X=o9aC62HfLfj3-r$~t{@OsH@h zXz-8s*Re}<0bM-D`N?|RQnN!%LcIk6B+V-b3>||XLr3BYyly^meY+#tb&%b76{=Rq zRiVIbZ`Q8*;GCPa5|V(5r82IZU&W}>xK{lhz@rPk%{8x^qOR{hru^qHBb;X zX$7x4ADrx~vx{P%0xWQaU?Gs|SsVj&mmXm0@e!OEGuwU~ALZ*`((aT)50o zBB+PaFJn761nr#OjtOuQGcD(&6dbL~lcda!-@;+6ES&per+0Ah5M0gm5F8c<4eC_| z&e<(uas!e^BtJ!gs8Ikaq{o;O;y6#L+sEu2smZxiTFsN|dJsn>hA51~H{r=wWhtxH zLZtNu`)3^#10t(Xh#?q&%0B}GzC(~}+VLnBe;A9GVm(@{=TU6vVQlEmaw#^c#U|GV zp0ss7Y8!pnHhO=e)OJW~J5=;EL#FUfEflYK`M?SOcmB?zzw>eL@rsZ0z6#Ik`-*jm zvcJ9T4?pt9ANu3#6OX#b9(IqFx+k=5;Ez4*KK7u!)O}9tK3DQjY5u9FK2QAxK#P&C zp&hF4d{o!{u&#T3x>R>St2?mP`#92}MY`8d-8p#w$Hl%QrN~h&a&&F)lTgFW9mP;b z(bK_}2x}rh{m==d)AKAa)r%wS1A1H=71k2z9`3hc!r!wdy}3={r7g)In`Rs(*6(bX z!GJZ>-TiKCuGtD0?Tlm%N_i6drVE%g4p8-Q$E}uI*VdQ+LC{9dmPV$)7b}I%YoYT+ z4_guK@q}1p+A1a}K;M~G9ydH{RleYLS54yWgFlG%P7`bq@Pcv7z>cO}pPxel9nNMn zXP)V_YDIntBImF!xE?y|Qn>#?DRfE;oqF&CEp(>nIYYZG(Mi!X-46%d;jj>czOXh7 zP{X>wc~s}mEf}hXrjE`$+7s-!p8IDRL-4Z*I#sh%3M#;{V8+3Mfa!7y;5|4%3N2`{V8+9Mfa!7wHM7##ox!n z*}{$F)}haiKE?Fs_QM%|p63Axt4%4;_>?|;?i88v9>7c+$$D55$DiVZ=l0Vq-^6#8 z!!26))ruF{#NZbljC*NeSpwKz)#Z-fVo z(;|Id9$X4TDdm|~WS`WPRO~vg?L5KKA&db*7ma%Pv;55| vtp(0jB71-Aw{VORx$-XYvvU diff --git a/struct2tensor/expression_impl/__pycache__/depth_limit.cpython-311.pyc b/struct2tensor/expression_impl/__pycache__/depth_limit.cpython-311.pyc deleted file mode 100644 index f9c3af530da2f46298d9e7cd4b7cf971cad2e3cb..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4244 zcma)9-ESMm5#ReDkGv!G{T<6be?_5Wla^}7A9akRk>ny!l@u~klrC^)cuR}ak;m*E z6N^v*r4UfH@Pi*xSPyy7L*d$R{(wB@PoRVYOdMd;K%auXFmMVK0Sa_xPozkFv{@Zz zXLo03cV~VxyZT!^t`aDP#(#1f!2JUo%@W>YUVjD5V`7ptF$Gh!grb-hi&9!D%4u1^ zKFJCd!|5>RW#E;x!ub&Jk#vOfVc?_bDDVnZ3Vc=6Z$_wQM%#OGVN>0Z(=n0UCZ@JQ z%vk&A6N!-L@C!;hZpN2LqO1DP_cLWjcbBPd(z3g(+c`Z`(&>XTqmGleOG#B-s+7yN zL-kd=qAzDYqn!d=RxjEnwWz6QoExfo@7_IC*C*kp{|<@%9?}$V%mVid-U&=F3XJ|f zpQPqsK15%+^6nOS`fFSeurmINPcvl*;QIv zqOKqLu!9zlIY5`5q<|5bD zOJHtCU(p?xg7b`(||arU)O)Un_9MuG*vc>6~;2F zsafi*xOO@9xy@Fbawbbtj>{@pcgE!>JkZ*=hLJCptyJ5=ljT)kYjYXvfdyx9T;p4xJ^@9spWz36mJoPG`cB#We3aHCjW&x3{j2$rHGEkGoRrjV9Qk;;WoAOMjd z(Qr=QBssy9H)3hU4ADsY_-Ip*C#&ys6IA 
zXcuVcZnoPrgxWRdR5P{_O83CNdSM4~*g>D!)xJYMBAeYX(r@qWSk(H6S6Nm)XK4bkl`lrqlS^SGLB;y4(=M#*g8Wr zuwfw8lSc+(vbp@vLB?4C^6z!>QjUIcd*k*Cd0)MGsl7dCIzvX^=NsFOD-_$+52CDj2)q!2K0zu6_K2m4JjL~ug%C5x7P_2Z0H zp>vGcED_>1WM^QQokemE$$20S9sppX@7Xrmx@W4Vjy7jY;5JTA0)g}yBN1)$>KAJp zYmYwNRYq#c$jjJJZTR9&Ea}CPwQ#bo^*x$?cmpuOZG1EqFm&K?0{(}%x$}%L9hYeH z{BUz|@P`4h0&Xa|d3UvX@qh32fcURbtpWsmuXS%;4ls(hQOCuI*5v-21O_lFY(ssx zbY_<&!cKtp-t>+6FXYM z(-Ji~!H?prJb@aeOp$h+2%XJOvF5&e5{Oc*``NJxP(FM)s`3o>Fb z?BJnV816P&^Y}xcdj4q39uWHt=x}}oWSzWHG*21XdT&=5uPNj8c<<(&djIg3H(&H8 zcKZ`gCwKaPz_7wSrXO?iWlztr^Y6Hi?)_3pRRJE_?nZO+r?YVzD+xVK=HgSC!v zLRWiQI0!H>Y0D+sc=AK#;0y?-AU6co^`r1UgXgA0-N4YjD<%8$lCh5(=V0iutFXv% z5Ir>e9tiKHhI)r7U&e?2bob%iT6hGY`|#!qEwQU5cC;x^o2toEK^*e%8N*u<@e8-2 zY8ddCuUN>(3(j7RC-G zcyBvQ=B$v}xpuHN$A{^h<#t~MY8g4DC^;s$ySlCU66ZhzpNyN&~s-$Nz#Sw(WjTczV4mA(jd_K2_!%n`eQNC z#qiPNhvBJ}FWnFPi8$b>@#7u;ZE(17q;VW@II){3{0?bjtX*vAbFbz(=~GGxbGF2s&#&Ka<10-)yc_P=T|4EYn@*s zED1A@3f|BZm`a$bkLwKy>+>n^{0+|X{x%Yu62Pp7$<_zniRtfg;D;kdxdHwQq-){3 diff --git a/struct2tensor/expression_impl/__pycache__/filter_expression.cpython-311.pyc b/struct2tensor/expression_impl/__pycache__/filter_expression.cpython-311.pyc deleted file mode 100644 index ffaa15ecd8d364083cc3956eaed1bd1bb330be75..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 20909 zcmeHPTWlQHd7j;!y)SPPZ=!A^>Ow9_lM+Qzlq|`*%aUc>bO$vqIz%uKLn^zq_4%SI)N4g2oTU3ZJ)Y|6reDWqThec z%d2mz^{*7MDho#`<;g7t6@V1Z= z#)Xt8<(=`&ddI!9zH#5II4;im$Ne6P_ss-mgX2LK77-4OhgjH;aCkh-!U2RM;}I4P zA{-r$vT$f7Hd{Ad$HL*6`tgQTL=H_injJ|+<;ZlC^%le1D(kJz@m6myHyKGaTouPx zuUK27qx{OXHKDdC8XA@SXT$?TF?5p7JXc2pEXw> zH>B3v+ep>h>zx!+8&K26$K=_FJQ8b7LTj5=Xl)bA)rnl2SIE_w3ZE9@Tay106okWy zoXpEoQp(5|CHc~vBIk1HY(~mYCG%1`hlI)WOkP&xl=Pacu%zsyl*~w3C4D}fNzO>w zjNBiIoS8x%DTz=1Vpf`)KR1(}kY?rlR5q29(o`QRo6V-?XXHIbm5Fng6S*`pX3nF+ zRD?G~u@h72nN+`YFehD1=ciCCmzT16v>;{YWF?tL2XlR>QqIZJgna;cMV6(DQ|XDR z$W#)YpP8XtmA=j;^HVtq18~GrsXpo4e4aPQ=OQ{8DSCPseL`Z@F<{o^`abJ=ukw19S3;fIUL{$ggdH{d#q(2+L8)7woy%X28(pFf zRcoAe)SY0w8$FR0BCsNso?X2If zQ6b(v9?F5}L35aV5aV@n{-slxKfcb5;%)0H{av^Y{Mo7rF|x6R~Vsx8TWLvxuY%7ClNpPa}G{W zxVsY-cgFjqPTOcY@d|WGu%t-jD8}Yu8>u@tKgXd9JY^vh3dHvq+YK|u*6w%AMNi5^ ze1A!7jocjc1pI?)Ox`g!+vj2(+bY-@$Ly8s+gK#dQ=C&Ov+v5wUhQ1L*<|JXD{%z_ z&M$%0**Ov`kY@&HN7(kphUSoc2K+ObCEO?2o*>C`F%9BmE_rfNp2!n}B+q4ElWlPR zSngoWJqd5fPQDD@$$B$m>HB?rKCmLL_-oeaUxLcM{8{iiTXMc#3FM(qna_f_5O^Q8d|=gdCAd(l zUZ1O`P(Ezd1LK0`fjtYtLTDkF_BiW>D(6Z5BL;koDQe60kaRw0Y9P=(z+q=d7sFpJ zJ)53MDx||0n?5m{%w6b522)NLsuHcZbWWzN&+#1UAgOE;mn()6$E3|USxTqmOg@dR z@BgGVl6ngGBt-7 ztJ^5U1svvHl3rrl+~_&aNWG@!;{m@2ijr0I2qo|d)WZ~J zlc9Gc_?jQOe3~!8QA>eWXb0(Ga4Zw7xp-KqLyS^Sq=5+8hkAeuOFhafXL!Srp3KgjnO zuIpytcCc7Cs@07yc9-hf@78tRsq0+YU#uI{>IN6P?}gj6@TR3LH`{NIs@;2w;eA?o zpX%NB5FHWpMz$}zeL_A!ZWctp1JYpia&2*g&+Wd{7$5}5&$w1qAq)!$hc%j zVJq)~8>NEzsFf&82h1=~D!34+ic;TM0j0>KbctZDVOk;(XTaxtAlo|KZEV|ZYuBDorr;>cDbnIm3%Ez0n1U@u0FAYo)KUaHN7W z?G~UMCKKESC(S6Lkh_Kfi$X4$c*q7R5+xC&DD5D+pN&#kOA+K?$%Q~@+UQZ%QZwQl~lz25Z9*~16za!0?3S1SPyfz+61+NO@;Wva!p797|c}Ql0z4B>4K~`q%*1Xgq%y{vkB7LC+vJqVfsA$ z8pvCow}re-KXUe`2GY26ERV3@T@VtMZj6;ro0onKx&|~4WqENdPiO3{{c!p zBw$|@?nfK6sHBP#=Z~sgA=rm=o2~zt(uKEMf-;0JTl0DlGnZEqiNgBoT`-b=OvMN& z7KKu@Q;l}AcH>c6F0vZ1T#-bAI}s6%B@(1H8YyT4_CBJ$`AI#%d4?^J9-0Hx%S!$- zn+64bdO@!O_~KUR`z;)n{FuYATI(H^1z40tD`Kr9Ps!yQ&MfhcL;TI0w3$4)u89$ z9LQr~E07W$@DyOsH0aU_xNDilhjvU@pXc>S-7<5Ge0Aan=SYIYRY}nQ5z%I_) zvN9{D(&T*30Ob1hn4yAVKgyT%XgZfr*+$tu=m1CzKj%Ed)haLqnE8@>U!Ovth@&yHF^yUI*JlOUgFn(0CCwNIc3 zteDv>`Q4}R0%yD=&&ac+4`I(6s!vVWG(q-pofuOH!Qy@fEXomtm7_!m9y#lf;t61M 
zAmhgByTO9Sze1TDVJJc4Dkn6aU3$7G4rtblaWiwo+?{*1Am%jl2cud(%;XOp2CbE~vJ|eUZx=0EyLtr))$}^P0 zHp~GE9Ry)A!T{XE6iMW5N_Y#z0k}P?x2G&de0@uU?*VWT_;`qYm88!1!et*K*lShN z?)72zEE4*>IP8RI2G)H3sh@g;a4W$yGQF8p2Dd6PVd+Q6PQD?wNq>n==+wX#guE^1 zv=w#&l+tPUFV*l6x%4#n#vCAb+M)Tm+BJoT;4Gu`=gD7SJdSorR(bSq=khd<#xrf@dh(vtEDydYlYwHG! z{+*hCC-LaMdr|3LC~_^RHf||~wrZiRYS+^Qg8K*r_Yny0`_$jS%k3=shcrKcz~>ve z7u`hV!qcUWl_uVY3m zggCPdAu57v!Gk?$m2z^%S;SxQUh&y+0z7VUsoj_CEUC2S>(B$^A=_@6dcz+0ecU{2 zrNPVOtlw(ORsYT!yynf5FT!|iI3ZN{8tr3T6+gPYZ_9dCxxoQ7z$x4r4{kDe5Uj^n z#Ns}a>691In|Pf<7@#-7FP%TjI8@Z}AUgBlt({9K3x}8#-C)9C4`~7ROnf>NA<8 z6l#CB^PTw8No~upy6d@O$0@Djl))R|Iq(Ju@tlN)tMJAII&VEDnm>FIJn&cp#^!~w z`(5C5ahNTYO)<$lFyo;xt{axk|27xK?>wg^EsF(g7E%>QDo@)Y)UJ!Edh_ z7pY=b{kWFpmmAT8I$Y;z!bKok?^%k-u3N22HOi}}o6S0rQfJRuX#W}(B*%BcjqKa-S< ziVj`wHqIUfn23p2$}<~?oi$Zlr{u525lV_=98_wG{n#^DJgGljH=hc`gnhNE2zTX1PXHBJOQ(^MOK7Ch6$!Dg6RV$X*zod<>{ zFg#&tTg;PYA!Pa0e~dNV|CIz>mbxjgQ=a8m<5*FMy$B#4md#9&7kNi9%_li+JH zj&yv^aPv3GU9a9m?J!dy8fvc{zYW?cD^X?)x^Db z#iMt`qoq*27U}@K7A>vctgRm{2L=Bz&jUg92ObeAGZGP+I~I>zJylx0hQ1c@uZF2Z z!jCN>odP+%U=R&}4gQ?>o)(zl1ARqClrl-gW&h}{QfZ1^+_OsRV#Yd}DLu4Hj!a_0mfoT$EuOu;_g<{^`og{T{$FfA zsoV?aGY5&=Dhas3_{D3tT>~>-u z)?vckmF#x+1~tKl?JdIVTVB3#s0oMRWZlA{9bY3aRrve>dqYuU9pRtmPL!S$uOrZ>*qi;mp>_9yvb^pqAorHWikIrg|E zp-9?%o=JQk`98phKZe}Gb|JZ?H!&B0?NUL`%_qvABHxp3nPqk2mj}YbR@6h>&~R<< z-DvzyG+vDMYSCU*?B&cs6Lf1$Rlgq@etM*$hH>SDWR)^SggGf970-e~oTmq0U`PHr zg}6))zQ`|>DRg?FoFbAVa*+u6kaY1hR3T-7-f1tefhz|nbP&YuPGkxak$lmHQ zomlpYKGHitVj*A0jd5-LARXIul-f3FZGB}Qg6zF2Y4-TKm!@tOw1Feqwxi?~>~@nz zp?|O(ZSoNg620Hhb|a{@_h=2hWj~7!2rcVbyIS+ka*#!ZgvL!;Lw7mM-Xg-PjvHO? z3}{U~Mokf6*M4o}S$n;_c7ty7_?njT)RiVT!GY2V4_Rn%ol@}Q6#uS51AgE0o>J$w z_u^$A3vM6%xI^1Lrfq+LN0bGMcF=kc{IY9q4t=;=+cZpd!Wa8k!nYJHksc;loFI|y z;3Ewt-$-lqTmD}8@DkL3 zVqjDYjQ+Ft=b^IC(|8hEI#Eb;(8Zol!_jLSZ*(lZqHP^6hDWsUi0U0-%jj4(;$^mM z2N0;WY*-+IN%q(m%RO~mii+H)@_l6I4j6K*#64^-><7&qHn)m<0)B`hO*8==2w(!( zVAg1B25zs>b{s5*4{6~;s`n5d53*PfK^HxL^hxk}5Pq{h35BH2XI(EUS#~j$oDKNx z4bxA!tHJ)entpD^(h^p;>+c}NZPz=+Yb`y-wjXO)^))Q{3JB1h7z+dzyfz53d_;gC z%SU>FL6p*9_kpICnMKSa;2TrUVPu7khBL%w@|gF^2t)hh5$+$bTmwo9#h7P+YZD>! 
zCs&-Kn2a#v!Dd6vs&eVnatgE)pItyL+$m7G@>10s&{ESG(7YDTfagoW z`l}0XDQ{l96WpK%H^8z{+eVAA-CAt78rWU-389UM=JJHBMh6;M%GeXS3+M!LnW9Ci zOku|g6Hwh1u!W|BdRB1-Jf%jvRk53OdQ7QeOIEk}kD#WSHoxC)^H;XBBi53NRjhOC z2v4ZuJh#g|r+1;U&7I&4&w{1!0GCT_kgS)SLc7UmMk-4o76XSdi;koM_x>~Y0CO3T<6n>ArWkG; zaYK$vE0VqJ^lJs*fs&!Zbvb^Kn|l!Wxd7n>tRD7oZyo8!g)J*=L(^`SyJ+XVgY3*X zu(hu!Zh~_FXKXFkXYRFa(AxTc8q~HQQ(KRd@qU6C?bM?6?YV4gIa7*4GexN2Clwv{dujq|8`8Et&o#-jvrs?1~z*X$pOxv(L<^f0~xkVbqTlqr0@|E>+yc2S7IS@6?)C z3g`h$dL~Jj1s^gkc!%xIEVx3?3Zr+K=le1ChlY9Hu=?7<-G;4q8n!M?7aMkH4LcT} zEd?9YV4G#0U;3oE7(J#%kE!CZ<(tcr2w%!PSH4Z_N>$htP;QB@w5-bv@1wK{7pbm| zL}-x8cZeJWv0Jok9}#ImiQIQVY!+?FJD_?8=z`21!ZIM{ zi#x5^I2mtt!mZf`-wOZ+q7G}8MLVro7G+tp?9E}#8Zm}7%idVMymo_ju4v7UTE^@s zw`J*e6kYPM2<+J76+0F|ryYBoS+FkpCG1#QmeNrPfl8w`r)O2OFG!doxB`SP9MIyJbCX)!yC zu^}xsqy~nNss-CgzZrI=z$xM<|CZl4EclHX8-GKyV$xqWQY<&Ri(#B=DQGU=FQ+s2CmAqQk0afY9%U)~~!v zBlUs6PUTkW0$vK;H5bRF;*)e2!UJJiXts7JGLZ8F8nb#JekNF|FH!-v{c9+!+!-!XOjG6)A=7@^y5Fk z;1@_E{I@-m_3O30qX>UrFdzoqr`^gX;V%37b^sUr88Hoqr{vO*Q|@fi<4}H%@Bn zc9aE#Zl^v?9J zp$C-ksFJ$9(eu=8c;gP01%yg__WlfmL+F0@fY!aI>_?EL*vPheWNd7?eqL*R>gGjl z;HcVqRBJp|_96cMx()AOG!*LDq4n%#!LopN8|fVO;H-8zx!r^DKcIw1mDD+p&$H`h zQ0v~0tL~m%pA7yS$H@q|NzCIp?77*c^&9|~BlyWnKYxXXs?g95>cMid)S+z}c|Zw| aDydVRos@eyxlzzIVnUJds7fjmDgOs+Z&rQ) diff --git a/struct2tensor/expression_impl/__pycache__/index.cpython-311.pyc b/struct2tensor/expression_impl/__pycache__/index.cpython-311.pyc deleted file mode 100644 index 6d2a44de8f53145ba93db869b46c736537b64c02..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10422 zcmdT}U2GfKb-qIm|AwSQQI=%Mb~2`I*|haz(_a55YuUU0TRVng#NKU|n+2EV3?*tw z4tr;~wipReQUmcrAB-aD`XMRwu-Vwy4UmThFp7N$k^lwT7t)x(#1uvrD1xGIH0%Z1 zpe=gNz2tC)l5B6#edvg~ymSA5&iT%F4u2kvN(7$8)_>*YT)H*Ss>&? zqL4IE1jSPnN}jZ*8FoMjrH#3q%>XMU*(qu%$Ec$6i7{fq(v_hghpYXtm`zt%jA*RgxUa`uv2< zt45}x=>?;pm9xc6p{%GkDBjR46K`#)78kQ<8LTL6s|6s+QYrzqp-CFdO*<3dIxiGHP|Gut^xk{#NxG`*&~Zloo+QgE zlp8!N%O7NmGxDKBhoF3AQi$q5-V{L~UjHjcDZP>YKGWwy_@$OoB`arG~K9_xC# z3i#1}`4e*uVxMMbfzP?H&cSz@^1ZC^bHS5dFWxV6#AS6p(|gQ+&MCXbnUN=-vy=6r zC(FIvQ6Sfs{`rUAGye02-%O%6Uk;OVru!X0kpW5Ms;=g%Mft{pT9%izs(d3`Mx6%r zuORkRD_YqNEm;}hrl?Fw=QYszo7qyOs7}a*ysLM)eSjERY0Kfykj*5y}f*ZoH@-vq{#yg4QPJDQKA*9v~C!tH<+ysG|X6(_WO z*Q#)^Nx>S(WsA8gps$+ISTh|8CB!8Ttnk*?eamDO1xiFjEcgyUtFo?9EYb@#beuKz8v)cDu4@J)rZ8UsqQGAHy+bQmc3k5O zvr@2pM;nw{KwQ2jkpC&x$$i~3Os?_b+X@0OccY`@(Ei&@7}~431}cEnU4di!K+V8n z6el^W!GWkRT1W0@;9AG(Ni=$tL$)(Cb} zR6Pff1o)KEXhx+J#9E{g7=}hM>B9koxH(*gF2XAw(VyO&UeHSFbVVsuDa|fTpH=nu z4XrYLL!zGi&* zQGDuQeCpo!8u4$K@ozlw3jHUa5lC7-F&cd0^+};;1d^6S!h@}_Pa13ypLg&nOh+ui za49X)d{n2We+#nyRVpp0s$c|g3H({a!)K2h5Ztu+yP%XS;)>4*cUptH0^FEC%OWVY zOLg1r{f6yAKzkG;_CR~DG1z$nSP<7bZceula?3t9Ky2F1?{4Q_OE0Xfh^|%HJ3A&$+VFg(^CP z?feL9hL#M{F?j956CK9d5F$?ofyEl?1rTnLLu>`iID;|9d(u(P5zg?TBREVPH^AV? zb`|MvsQ3x|>klD$8Ti?AqUjIZ-u7d~9G+?nA2f##0^*FP0CC1s4Nqb}&^6tRZ+jF!{4jpFe&k{!e#wkqvP*sm z1+!*+_KDBae_D7(Fm3sKaey;_KfoCztI027Br;}(<+Xve@7!Ce?>^B8&zRwv)$wMi z|IXHWXteGbrRe$R;NN~aiuSYxMLz_PQ%PEYc!oz2(q6@*ii}w)rdUWahu!}(u2TaViY41wSPdf+?AxZxd-;ikb^5O zQ|Mmt?aD1`f1wN(PywC}tl_!}MqDpkFCybcG-7-aEjrGnRaDBt0e~>YVw>_n2Sy@c zQefc=Km(8#=twNup$I0d$Pc=sS$@8}oK4VRLufsi-MR^x9}6E6!zPP_xGl;i6jNBV z$1`k;0@k8ymt2-b0QXqxxMF#X zyyb&qJ9@+l04YI^26SH3)ncBG!`iJ-CX*{>bv={OaYM4a%5<0RhnL!}&97ssGv9Ba zp*{)8f2@+nV(9x9Ze4gJZh0tfx$~AewD%YBPq#M2lqsg_VyZbfx_b83t;}tS3Tc`YOsz5c(NCqvbpcl%zkp;(+BAQu! 
zm(^^3PE*utjPqLkXONRS3L;KVhg;19*im{5U6ylyo+n*6gTErXJ{zZ_4 zhw4LbG^C5Bbg?d8ZdW>||4Z+R~fCLZC8cVg* zOYF6$4+`;9bl=P~%2P;N!M{17^C=uA}K5dII_}6LR0f?-Hlb zv18}X?C1lo5MmSAVr}9j_rqHM3l50f3E(R{bVvF9_M?sPF*AIuE*|5%LeROt1rZi~ zP(WBIlPPIRwTSshCWCXfYXX^!qUACfibrEfUF*g0l!{PF4`ae)96~}J`KEY#R)B|D zbc%{>1+MVW2%W_@@w%oJ>Fe0}JibTUz@1Zz#Y_g!@F0}3>psm991Sl!y3ep834#C! z-T7;IM6Vx25|W6=J9am1?nt!=ps=r(Ncg?u0p{S}erziDKT!01 zDyMjz)F>}9Vka_U=i7{!0j=`4iCv7*w~dKiXtVx6<($MWyeYBs^$@!VBX&MU?4pd= z^*L9sV2!BsLrvT!b^{cCvLm&B-k8`qe8LTB-8!%21?&!3UDtqcm*Oj6da>R)tRQo9 z3gwC`YJqBu8v67{ko2-YC#o&?;M$r4@t|w_Sr)n>MZ=C3fC;zhgBu;&To4h1JT&*Emde-?$m4pHO|-R=ac;a23X*UhLk z{1Roy$oM_+n}ulEWVTqCF9UgJK0h*V-h+(-(S2Af@NS$Bn0{7HP=MzGRh5Jz)aQ&l zt``X&DC{o4jw&(+!m0m0j90KJaq-Chy6yRblRO3qp-XhK;qLjcY-MfI;uwmxn!?I< zzd7)!lW;FgKRE0oFQI4wDe`1%VBM6KFT+L3ZyR zbI0WBSyLQq4j^m4v2}A1WJTT&JMiOghKe|*8PuIn1w)nY?$l%62eHIqZcsguv#kb;r8K&-5 zRK4FBa$L`rT_gIPz|rOY&mFPncZwef&R-+C@UiUg@_FzEY^ZRLy%N~ef3SN%?=I{F zYQSRPf^Ep^^%audxmvPUz2p-9ZpFNJH^r~*VhYE9aoDU;GQ=q%@+Ufj$w^F5xh(%Y zMAg9+_Ol^$`YP6-?tw<4tGLnCQ`9v}>dxGbt$|C;(O`X8i_x*#ffq|?hi> zg@m~Y#5hx)VCT#SInU+t8GK_7`mwP$j(5&o3wj!|^fgR+ z6&I;#?ZnF;OOHnp1$Co{QygTF=YIvR1R@kY2Z_^%V1&AM_{%MwIS(;dwA*<&*$B^? zVF(w@@_nH5KrdkOZJZU_lQutROTpw1Sg{{!T>RXJ{2Y_M81U|zx%br;fn0NRmpMAs z@Bya(M>{UEYlXTqs^zZdQv5*Ja&LN`yx5}T&ui;Jc6f;LIV7aY-%B%Vvq zPuT?h6esAY1>P4V{A}^?k|Xq+@MZapr3yp>ndij)EM!ks)b-05^Rk$4#zW>9X6LTh z4p+OWZuE%#rpBZ0694%vU({}JYknDL5U&90SmIUqF@#4UcC%8%Hs*bC9`G8@gt1lT zoBb*HDt91f=>n9YQ0RXLY0E1JLTmeiAc!r((raEKOxDT7W}nZ=fx7$CB)jYGPm}Dc zyFX3xwYvM$Bs=TwPm@g5-Jd2Isdqjt|FAHKfCCo|xnCdo9pDmVn#sveM_XRVeZC{H zHe>EO{$TfK$@;D{=8m&19~84zjwH217`waQ9Gh+t$URU#E3iy+?}3k&T3#sne4D&B zX6`)tK>RFJ-}#!k?e&%qidid1a$1mu<1Mm2SsVN4>rb%!S-0}MkP<+`)+hI*Py3%> L`Lhiw+5P=5O@uR6 diff --git a/struct2tensor/expression_impl/__pycache__/map_prensor.cpython-311.pyc b/struct2tensor/expression_impl/__pycache__/map_prensor.cpython-311.pyc deleted file mode 100644 index 208eb41f0367564f783ee9d68ee8157c1eb87c2a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 18599 zcmeHOYit|Go!{k~l&JUHvAvRCq9u!xEIYR3XKdMV{7Pb9+9=I2w031vrbuR&ieptm zmB0tffv@LY6T~>A09&}Ja`6Yz&(Ed z89sI?*|9H8TO7KghO;yOnfcGm|2@N>`Taft&v@>?#%gv5!hg_9;i=1K?*ENL5Z)DJ zVMLHc*%lQiY$LV_`-pwQG2)nTjyNY=BQBBZ>{0iGXT-zu4&=QfUY2(v?;G*4ybJk? 
zkqVY~BVRdE$?~44f1+xnisil0>X91Rr+CL}3&_Y7O2v2qBSEfIoaXn|0+=BC6`Ne8 z)Qr~`ipFfF+G?diu0d^Mp;oR%KTXBAI`f!gHFEtG$4K+n#H6Lr``5yx0hqMfgrkDo zctwz#%y_a?b)Fl5<0jdAMhI*S zw>bsjpn5KnRKx0JDV|W&a555)Y0{V)pOCafSk)9M9FrwAeEz&5OGzcB#Z^s8#-*?n zRl;M^SVW1+JwD&r3yKO*{3kEPrNrd9Xk=8HP?8tovL?mHq*y$uNNVK#g(R|<2BnE` zBEYmW*z|9Z-MMz3HhLg1JrQUN{qOJ)y*(ueZ!F?mWP%B-_lq- z8jW9y#Li3MXi`yQSn}5tO}Z3`My0SU^BJ&Jr@5T-VX;S5C5-i>cPs`LWySJIqw&~l zikj5Qo6}z9H!{C_-E66WNsJAL%me`?)VMr3Y953ZI;LC-squJH+RHX4p)v~SSuuVm z=8K)cCUvu1BAmR?^DL6impa(~bi5qkH7}W~*ltmkR9<=xr7!aq(eT7MIV`>2eR)uN zUFwo9^Wr1nsHX62a;zs3OK$JyZyot{>Im@VUqMw5T`Vi9|E?CxLe+oFi**(hsw&!K zG)eTTaLWBUe8-s(WD7QYJU%qF?0;OoG`0OoJ~hEw*!DN1F9S})b0{2*hR;P6!+S)H zPbsl8O49J0PLQY!M-9&z<<&_gHmVrTvy+La5)h5r(QtHhG76pz#Tl4}LLYGlms@zN zD!qlk^(b)=3*g(GXEL6Rk1M0eQ-*_>(C|_kA~?fIsb+YKLc7PPh-e{INsy?J4X2#E zoKTD^VqhkuL$UCLVz^W#IjP1*-340!-(LKa6?A_Wl6Qrqh_C!JEli8kwrTscW7;|G zdLOMmEVh)zl6?qmVhjJoCI~0P2@SfzB5@$EBSZ11kGWPNkS%A@=~$196d!S zHd78tvjtsZNg*srTwS^l8NI-Iv5bYmp@WLc57k6NvbEwJMPiA`BrmY^N0Pa?x~x%g zWd|`KuF5c8I4LE=7a^~yK{TQ@$L+I=Yzo#bOtdm4`m_2S+< z7|CwvP*4eV{E+#ppi%+6)f`%(#6ppYMDzoNC6O+x5ACB)e<(=Xe>) z38H=IJ{kz$2=D)x^FBDmqCdZ3*0+rHhhhD46duO<7Ff=Ue`5U-{r{d>pYKNA(O|r% zlJ8|^ZIv_}XF&_Jx(?;n2{_M!-H*(IDn%1U@n3u@^11uBp?I%C0V~3Yh@g>87Dw!| zO>vCFOXH7hS6qsF+*5dm`^~~o^bSYfJznTbusDwnrBZe({&A9P>>2aS31gz{x>7k( zExVN(rM46zDk=3bo9uzJSkE9#SXdyufN1a{u!2BS4KOGMh)PfhI@OqGVXE<)XMpfi zlX-R~IJaePKDqje6X7MdQVASufI$nwQ6>yojYFS?&}K7tOoyo7hQIa6JZZWD@p5-p)A> z`z2=5cT7SaOOkhG0w?IOq)ndNmQ)o*>WuU#Jt)YJ6yOfH1~nRmPUgcyAD@6nekv|2 zXStImoh7GyEDnPx0>@B`PpYFpgbXAYv`kNSUW7Lu3rMn}jVduY90Oc@3|+kvhqE4u zy)uZV(xf(t2A3u6LVOY$n>(;T2QDL7r=2}LJq)D@7)BE7(cB?=8%3-@WiwrMy~w(N z{{{po#t@gz0YQn`A%#`r*z`ze6b0RpO3cri@hvb?v7}*WZw1JP&ns*!!^NAGDe%!E zJAoo7Z^!@r*O0s;TCLyrQT8GJ`I9o)#f36KV79S1PDTZLZj08aqDU6B5TY^I18Z6c zAuh}xSs=^#6wHwZkV!-)k3mZ+i4bNq5l+JSaB7nX%NU+e2xe0<6jC>!rQr_osTn>v z82PA8z-2fzB|4_I1EOrshBp)%jfOQX6w+L%W`@fw6L0D!yicuLbz^!8qy7V6w8Kcg zfW6@noW3`XUOBqtXjpVKWIgpa{%oO1?>mvR3(m&R1e??KIVCxkxP*qb*~3>(WSiFG z>+~%Cmhc!F zJh&ZgByHAIf8Nez48)l#?@vS!Q32&6s561h(vk>)#Z_5RdsfVo49WXsQobvk$1GoV z5EehMU3Z=lJ`mJyR0M2>t*6%zFKRR=I1nUL0IYv`k9uf6_A~X*&G?y8gHyhr2VySg#v2`fZx>s-AyVUyB zV(U|x*28-1Vc?$e59|Kn*&|tp@9Ob4kH2y9%E{T2*^1iPlU(hZ$&Hqkk=qn9S8eY; zHnF`0PEI`n!Oe(sHYi*1if39ZCqBmg7V03;z^Fq)h6EC$58sS<%{FavxIrrwXE>#B za6Q%R+nJk=+uM~>+MZg4fNj?Pr~;~@VrbxX7ZM6*?O z*P}~Sy^B@7nW}!hs(WF|MRY@wNO87v-278O_VF?W^xbPtm+mt3b+`>s)NX@LzL{H zWG@n4HY&FM!QZpusR1{T*1 zWY+D}*X_LH&a@rS+YVr28P7r8a}ZvB{UH%@w%vxr( zPHkF!3CqB;)DH=fSphpdH?VR`i;G4KVU;jed67pG)}|N4k3hDjB-NwTziG#7xuP#k zU3<_{t^oX$y5dWDucG1F(Dk9Tqlw|mmM5^D%vcIThe9x1LlZb~ic-EZ6rwH5*SJF= zIX((Ym*UcfujE##L>Ziy-FZ()C0E#R^H{uQR9QVB!*MPij~ce*nBhEp7U6Ja>#9Wg zDtXF=<5|QPIk)eocVEeSQb`)C!<0Nt$qA}=oPlkwo~CyPCC^i#KW|!(D$!^t6cE(_ zO4*8Xb*4%LN(t?Tmd08`B^Ca1HlH1)G9@h@yL59zU)!G(kju7pq>kP?o!)We(@otu zyST9@8{DY}hbeo6n10PsF=q#0&LymE&$)SSYxnIIo~J%oN99R_!(NwancsLjptld? 
diff --git a/struct2tensor/expression_impl/__pycache__/map_prensor_to_prensor.cpython-311.pyc b/struct2tensor/expression_impl/__pycache__/map_prensor_to_prensor.cpython-311.pyc
deleted file mode 100644
index 85689438a5f43bacd8f5d019d359a7325722049b..0000000000000000000000000000000000000000
Binary files a/struct2tensor/expression_impl/__pycache__/map_prensor_to_prensor.cpython-311.pyc and /dev/null differ
diff --git a/struct2tensor/expression_impl/__pycache__/map_values.cpython-311.pyc b/struct2tensor/expression_impl/__pycache__/map_values.cpython-311.pyc
deleted file mode 100644
index 1981828fa104111448a7d1934909f4de0fa4245b..0000000000000000000000000000000000000000
Binary files a/struct2tensor/expression_impl/__pycache__/map_values.cpython-311.pyc and /dev/null differ
diff --git a/struct2tensor/expression_impl/__pycache__/parquet.cpython-311.pyc b/struct2tensor/expression_impl/__pycache__/parquet.cpython-311.pyc
deleted file mode 100644
index cd2057f7ccc844f887d1eea11c1deb2b01f8fca4..0000000000000000000000000000000000000000
Binary files a/struct2tensor/expression_impl/__pycache__/parquet.cpython-311.pyc and /dev/null differ
diff --git a/struct2tensor/expression_impl/__pycache__/placeholder.cpython-311.pyc b/struct2tensor/expression_impl/__pycache__/placeholder.cpython-311.pyc
deleted file mode 100644
index 62b71ca0946dbb2b3c8d03b346d07fae8447fce0..0000000000000000000000000000000000000000
Binary files a/struct2tensor/expression_impl/__pycache__/placeholder.cpython-311.pyc and /dev/null differ
diff --git a/struct2tensor/expression_impl/__pycache__/project.cpython-311.pyc b/struct2tensor/expression_impl/__pycache__/project.cpython-311.pyc
deleted file mode 100644
index 526fab92df8b84f08ddcbccc9f70a4c4a101b9c1..0000000000000000000000000000000000000000
Binary files a/struct2tensor/expression_impl/__pycache__/project.cpython-311.pyc and /dev/null differ
diff --git a/struct2tensor/ops/__pycache__/file_descriptor_set.cpython-311.pyc b/struct2tensor/ops/__pycache__/file_descriptor_set.cpython-311.pyc
deleted file mode 100644
index 43cbfba9d5ace85d4695d9809630941ef5a61762..0000000000000000000000000000000000000000
Binary files a/struct2tensor/ops/__pycache__/file_descriptor_set.cpython-311.pyc and /dev/null differ
diff --git a/struct2tensor/ops/__pycache__/gen_decode_proto_map_op.cpython-311.pyc b/struct2tensor/ops/__pycache__/gen_decode_proto_map_op.cpython-311.pyc
deleted file mode 100644
index 992e7c353646f983d39e8a14b93644119dbd76fd..0000000000000000000000000000000000000000
Binary files a/struct2tensor/ops/__pycache__/gen_decode_proto_map_op.cpython-311.pyc and /dev/null differ
diff --git a/struct2tensor/ops/__pycache__/struct2tensor_ops.cpython-311.pyc b/struct2tensor/ops/__pycache__/struct2tensor_ops.cpython-311.pyc
deleted file mode 100644
index f7d3b4b0a7cfcee05c6a2e24e397f430e01816d6..0000000000000000000000000000000000000000
Binary files a/struct2tensor/ops/__pycache__/struct2tensor_ops.cpython-311.pyc and /dev/null differ
diff --git a/struct2tensor/ops/__pycache__/struct2tensor_ops_test.cpython-311-pytest-8.3.3.pyc b/struct2tensor/ops/__pycache__/struct2tensor_ops_test.cpython-311-pytest-8.3.3.pyc
deleted file mode 100644
index 724d8e9a09067df227791df601541bab64b68d12..0000000000000000000000000000000000000000
Binary files a/struct2tensor/ops/__pycache__/struct2tensor_ops_test.cpython-311-pytest-8.3.3.pyc and /dev/null differ
zD;rXw2ASQdQ1y+*8;z;zZ8Ez}$5LLP#tRAzP{>s_r9u@q>TcAfLZvrqZq%fzTW;mw z%1{3wDo;@X5VbxP$}^+YrmAb5!Hp)SJcMNhuV<-q1v$%*ox0h`*eGW6t_Gl=cn-mO;Z7Ye*o#RibM3T%v7haH5q$LZ#(@Tth{;R&#*sOtLAGF Iz$n%K1KI}@od5s; From b73780484f4fb8cbf8e6f5765a5447e4d0483269 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 14 Aug 2025 22:27:02 -0700 Subject: [PATCH 17/17] Apply Ruff formatting --- setup.py | 117 +- struct2tensor/benchmarks/ops_benchmark.py | 112 +- .../benchmarks/struct2tensor_benchmark.py | 2057 +++++++++-------- .../struct2tensor_benchmark_util.py | 460 ++-- struct2tensor/calculate.py | 716 +++--- struct2tensor/calculate_options.py | 71 +- struct2tensor/calculate_test.py | 304 +-- struct2tensor/calculate_with_source_paths.py | 156 +- .../calculate_with_source_paths_test.py | 326 +-- struct2tensor/create_expression.py | 172 +- struct2tensor/create_expression_test.py | 235 +- struct2tensor/expression.py | 1223 +++++----- struct2tensor/expression_add.py | 330 +-- struct2tensor/expression_add_test.py | 167 +- struct2tensor/expression_impl/apply_schema.py | 252 +- .../expression_impl/apply_schema_test.py | 144 +- struct2tensor/expression_impl/broadcast.py | 389 ++-- .../expression_impl/broadcast_test.py | 430 ++-- struct2tensor/expression_impl/depth_limit.py | 93 +- .../expression_impl/depth_limit_test.py | 148 +- .../expression_impl/filter_expression.py | 554 ++--- .../expression_impl/filter_expression_test.py | 396 ++-- struct2tensor/expression_impl/index.py | 263 ++- struct2tensor/expression_impl/index_test.py | 112 +- struct2tensor/expression_impl/map_prensor.py | 543 +++-- .../expression_impl/map_prensor_test.py | 320 +-- .../expression_impl/map_prensor_to_prensor.py | 588 ++--- .../map_prensor_to_prensor_test.py | 237 +- struct2tensor/expression_impl/map_values.py | 266 ++- .../expression_impl/map_values_test.py | 93 +- struct2tensor/expression_impl/parquet.py | 920 ++++---- struct2tensor/expression_impl/parquet_test.py | 314 +-- .../expression_impl/parse_message_level_ex.py | 337 +-- .../parse_message_level_ex_test.py | 249 +- struct2tensor/expression_impl/placeholder.py | 278 +-- .../expression_impl/placeholder_test.py | 123 +- struct2tensor/expression_impl/project.py | 154 +- struct2tensor/expression_impl/project_test.py | 44 +- struct2tensor/expression_impl/promote.py | 478 ++-- .../expression_impl/promote_and_broadcast.py | 113 +- .../promote_and_broadcast_test.py | 35 +- struct2tensor/expression_impl/promote_test.py | 1070 +++++---- struct2tensor/expression_impl/proto.py | 1188 +++++----- struct2tensor/expression_impl/proto_test.py | 847 +++---- .../expression_impl/proto_test_util.py | 30 +- .../raw_parquet_dataset_test.py | 2017 ++++++++-------- struct2tensor/expression_impl/reroot.py | 253 +- struct2tensor/expression_impl/reroot_test.py | 237 +- struct2tensor/expression_impl/size.py | 227 +- struct2tensor/expression_impl/size_test.py | 73 +- .../expression_impl/slice_expression.py | 306 +-- .../expression_impl/slice_expression_test.py | 447 ++-- struct2tensor/expression_test.py | 456 ++-- struct2tensor/ops/file_descriptor_set.py | 198 +- struct2tensor/ops/file_descriptor_set_test.py | 23 +- struct2tensor/ops/gen_decode_proto_map_op.py | 3 +- struct2tensor/ops/gen_decode_proto_sparse.py | 3 +- .../ops/gen_equi_join_any_indices.py | 3 +- struct2tensor/ops/gen_equi_join_indices.py | 3 +- struct2tensor/ops/gen_parquet_dataset.py | 3 +- struct2tensor/ops/gen_run_length_before.py | 3 +- struct2tensor/ops/struct2tensor_ops.py | 487 ++-- 
struct2tensor/ops/struct2tensor_ops_test.py | 1427 ++++++------ struct2tensor/path.py | 702 +++--- struct2tensor/path_test.py | 537 ++--- struct2tensor/prensor.py | 1373 +++++------ struct2tensor/prensor_test.py | 490 ++-- struct2tensor/prensor_to_structured_tensor.py | 120 +- .../prensor_to_structured_tensor_test.py | 174 +- struct2tensor/prensor_util.py | 92 +- struct2tensor/prensor_value.py | 426 ++-- struct2tensor/prensor_value_test.py | 123 +- struct2tensor/struct2tensor_module_test.py | 96 +- struct2tensor/structured_tensor_to_prensor.py | 461 ++-- .../structured_tensor_to_prensor_test.py | 1089 ++++----- struct2tensor/test/expression_test_util.py | 346 +-- struct2tensor/test/prensor_test_util.py | 497 ++-- struct2tensor/tools/build_docs.py | 206 +- struct2tensor/version.py | 2 +- 79 files changed, 15822 insertions(+), 14535 deletions(-) diff --git a/setup.py b/setup.py index 6212f14..2482936 100644 --- a/setup.py +++ b/setup.py @@ -25,98 +25,97 @@ # built by setuptools, it will be incorrectly treated as a purelib. The # following works around that bug. class _InstallPlatlib(install): - - def finalize_options(self): - install.finalize_options(self) - self.install_lib = self.install_platlib + def finalize_options(self): + install.finalize_options(self) + self.install_lib = self.install_platlib class BinaryDistribution(Distribution): - """This class is needed in order to create OS specific wheels.""" + """This class is needed in order to create OS specific wheels.""" - def is_pure(self): - return False + def is_pure(self): + return False - def has_ext_modules(self): - return True + def has_ext_modules(self): + return True def select_constraint(default, nightly=None, git_master=None): - """Select dependency constraint based on TFX_DEPENDENCY_SELECTOR env var.""" - selector = os.environ.get('TFX_DEPENDENCY_SELECTOR') - if selector == 'UNCONSTRAINED': - return '' - if selector == 'NIGHTLY' and nightly is not None: - return nightly - if selector == 'GIT_MASTER' and git_master is not None: - return git_master - return default + """Select dependency constraint based on TFX_DEPENDENCY_SELECTOR env var.""" + selector = os.environ.get("TFX_DEPENDENCY_SELECTOR") + if selector == "UNCONSTRAINED": + return "" + if selector == "NIGHTLY" and nightly is not None: + return nightly + if selector == "GIT_MASTER" and git_master is not None: + return git_master + return default # Get the long description from the README file. -with open('README.md') as fp: - _LONG_DESCRIPTION = fp.read() +with open("README.md") as fp: + _LONG_DESCRIPTION = fp.read() -with open('struct2tensor/version.py') as fp: - globals_dict = {} - exec(fp.read(), globals_dict) # pylint: disable=exec-used -__version__ = globals_dict['__version__'] +with open("struct2tensor/version.py") as fp: + globals_dict = {} + exec(fp.read(), globals_dict) # pylint: disable=exec-used +__version__ = globals_dict["__version__"] setup( - name='struct2tensor', + name="struct2tensor", version=__version__, description=( - 'Struct2Tensor is a package for parsing and manipulating' - ' structured data for TensorFlow' + "Struct2Tensor is a package for parsing and manipulating" + " structured data for TensorFlow" ), - author='Google LLC', - author_email='tensorflow-extended-dev@googlegroups.com', + author="Google LLC", + author_email="tensorflow-extended-dev@googlegroups.com", # Contained modules and scripts. packages=find_packages(), install_requires=[ # TODO(b/263060885): Remove the explicit numpy dependency once TF works # with numpy>=1.24. 
- 'numpy>=1.22', + "numpy>=1.22", 'protobuf>=4.25.2,<6.0.0;python_version>="3.11"', 'protobuf>=4.21.6,<6.0.0;python_version<"3.11"', - 'tensorflow>=2.17,<2.18', - 'tensorflow-metadata' + "tensorflow>=2.17,<2.18", + "tensorflow-metadata" + select_constraint( - default='>=1.17.0,<1.18.0', - nightly='>=1.18.0.dev', - git_master='@git+https://github.com/tensorflow/metadata@master', + default=">=1.17.0,<1.18.0", + nightly=">=1.18.0.dev", + git_master="@git+https://github.com/tensorflow/metadata@master", ), - 'pyarrow>=10,<11', + "pyarrow>=10,<11", ], # Add in any packaged data. include_package_data=True, - package_data={'': ['*.lib', '*.so']}, + package_data={"": ["*.lib", "*.so"]}, zip_safe=False, distclass=BinaryDistribution, # PyPI package information. classifiers=[ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'Intended Audience :: Education', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: MacOS :: MacOS X', - 'Operating System :: POSIX :: Linux', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3 :: Only', - 'Topic :: Scientific/Engineering', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Scientific/Engineering :: Mathematics', - 'Topic :: Software Development', - 'Topic :: Software Development :: Libraries', - 'Topic :: Software Development :: Libraries :: Python Modules', + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3 :: Only", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", ], - license='Apache 2.0', - keywords='tensorflow protobuf machine learning', + license="Apache 2.0", + keywords="tensorflow protobuf machine learning", long_description=_LONG_DESCRIPTION, - long_description_content_type='text/markdown', - cmdclass={'install': _InstallPlatlib}, + long_description_content_type="text/markdown", + cmdclass={"install": _InstallPlatlib}, ) diff --git a/struct2tensor/benchmarks/ops_benchmark.py b/struct2tensor/benchmarks/ops_benchmark.py index 045ad3a..e55a8e9 100644 --- a/struct2tensor/benchmarks/ops_benchmark.py +++ b/struct2tensor/benchmarks/ops_benchmark.py @@ -38,64 +38,66 @@ class EquiJoinIndicesBenchmarks(struct2tensor_benchmark_util.OpsBenchmarks): - """Benchmarks for EquiJoinIndices.""" - - @parameterized.named_parameters(*[ - dict( - testcase_name="equi_join_indices_monotonic_increasing", - fn_name="equi_join_indices_monotonic_increasing", - fn_args=[], - data_key="monotonic_increasing", - ), - dict( - testcase_name="equi_join_indices_random", - fn_name="equi_join_indices_random", - fn_args=[], - data_key="random", - ), - ]) - def test_equi_join_indices(self, fn_name, fn_args, data_key): - - def 
benchmark_fn(session): - a = tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,)) - b = tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,)) - result = struct2tensor_ops.equi_join_indices(a, b) - with tf.control_dependencies(result): - x = tf.constant(1) - return session.make_callable(x, feed_list=[a, b]) - - self.run_benchmarks(fn_name, benchmark_fn, fn_args, data_key) + """Benchmarks for EquiJoinIndices.""" + + @parameterized.named_parameters( + *[ + dict( + testcase_name="equi_join_indices_monotonic_increasing", + fn_name="equi_join_indices_monotonic_increasing", + fn_args=[], + data_key="monotonic_increasing", + ), + dict( + testcase_name="equi_join_indices_random", + fn_name="equi_join_indices_random", + fn_args=[], + data_key="random", + ), + ] + ) + def test_equi_join_indices(self, fn_name, fn_args, data_key): + def benchmark_fn(session): + a = tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,)) + b = tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,)) + result = struct2tensor_ops.equi_join_indices(a, b) + with tf.control_dependencies(result): + x = tf.constant(1) + return session.make_callable(x, feed_list=[a, b]) + + self.run_benchmarks(fn_name, benchmark_fn, fn_args, data_key) class EquiJoinAnyIndicesBenchmarks(struct2tensor_benchmark_util.OpsBenchmarks): - """Benchmarks for EquiJoinAnyIndices.""" - - @parameterized.named_parameters(*[ - dict( - testcase_name="equi_join_any_indices_monotonic_increasing", - fn_name="equi_join_any_indices_monotonic_increasing", - fn_args=[], - data_key="monotonic_increasing", - ), - dict( - testcase_name="equi_join_any_indices_random", - fn_name="equi_join_any_indices_random", - fn_args=[], - data_key="random", - ), - ]) - def test_equi_join_indices(self, fn_name, fn_args, data_key): - - def benchmark_fn(session): - a = tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,)) - b = tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,)) - result = struct2tensor_ops.equi_join_any_indices(a, b) - with tf.control_dependencies(result): - x = tf.constant(1) - return session.make_callable(x, feed_list=[a, b]) - - self.run_benchmarks(fn_name, benchmark_fn, fn_args, data_key) + """Benchmarks for EquiJoinAnyIndices.""" + + @parameterized.named_parameters( + *[ + dict( + testcase_name="equi_join_any_indices_monotonic_increasing", + fn_name="equi_join_any_indices_monotonic_increasing", + fn_args=[], + data_key="monotonic_increasing", + ), + dict( + testcase_name="equi_join_any_indices_random", + fn_name="equi_join_any_indices_random", + fn_args=[], + data_key="random", + ), + ] + ) + def test_equi_join_indices(self, fn_name, fn_args, data_key): + def benchmark_fn(session): + a = tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,)) + b = tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,)) + result = struct2tensor_ops.equi_join_any_indices(a, b) + with tf.control_dependencies(result): + x = tf.constant(1) + return session.make_callable(x, feed_list=[a, b]) + + self.run_benchmarks(fn_name, benchmark_fn, fn_args, data_key) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/struct2tensor/benchmarks/struct2tensor_benchmark.py b/struct2tensor/benchmarks/struct2tensor_benchmark.py index bf0c268..9a8a972 100644 --- a/struct2tensor/benchmarks/struct2tensor_benchmark.py +++ b/struct2tensor/benchmarks/struct2tensor_benchmark.py @@ -21,1044 +21,1135 @@ class ProjectBenchmarks(struct2tensor_benchmark_util.ProtoDataBenchmarks): - """Benchmarks for projecting fields.""" + """Benchmarks for projecting fields.""" - # pylint: 
disable=g-complex-comprehension - @parameterized.named_parameters(*[ - dict( - testcase_name="project_1_deep_int_fields", - fn_name="project_1_deep_int_fields", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["int_values_1"])] - ], - proto_list_key="deep_protos", - ), - dict( - testcase_name="project_2_deep_int_fields", - fn_name="project_2_deep_int_fields", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["child_1", "int_values_2"])] - ], - proto_list_key="deep_protos", - ), - dict( - testcase_name="project_3_deep_int_fields", - fn_name="project_3_deep_int_fields", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["child_1", "child_2", "int_values_3"])] - ], - proto_list_key="deep_protos", - ), - dict( - testcase_name="project_4_deep_int_fields", - fn_name="project_4_deep_int_fields", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [ - s2t.path.Path( - ["child_1", "child_2", "child_3", "int_values_4"]) - ] - ], - proto_list_key="deep_protos", - ), - dict( - testcase_name="project_5_deep_int_fields", - fn_name="project_5_deep_int_fields", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [ - s2t.path.Path([ - "child_1", "child_2", "child_3", "child_4", "int_values_5" - ]) - ] - ], - proto_list_key="deep_protos", - ), - dict( - testcase_name="project_1_flat_int_fields", - fn_name="project_1_flat_int_fields", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [s2t.path.Path(["int_values_1"])] - ], - proto_list_key="flat_protos", - ), - dict( - testcase_name="project_2_flat_int_fields", - fn_name="project_2_flat_int_fields", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - s2t.path.Path(["int_values_1"]), - s2t.path.Path(["int_values_2"]) - ] - ], - proto_list_key="flat_protos", - ), - dict( - testcase_name="project_3_flat_int_fields", - fn_name="project_3_flat_int_fields", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - s2t.path.Path(["int_values_1"]), - s2t.path.Path(["int_values_2"]), - s2t.path.Path(["int_values_3"]) - ] - ], - proto_list_key="flat_protos", - ), - dict( - testcase_name="project_4_flat_int_fields", - fn_name="project_4_flat_int_fields", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - s2t.path.Path(["int_values_1"]), - s2t.path.Path(["int_values_2"]), - s2t.path.Path(["int_values_3"]), - s2t.path.Path(["int_values_4"]) - ] - ], - proto_list_key="flat_protos", - ), - dict( - testcase_name="project_5_flat_int_fields", - fn_name="project_5_flat_int_fields", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - s2t.path.Path(["int_values_1"]), - s2t.path.Path(["int_values_2"]), - s2t.path.Path(["int_values_3"]), - s2t.path.Path(["int_values_4"]), - s2t.path.Path(["int_values_5"]) - ] - ], - proto_list_key="flat_protos", - ), - dict( - testcase_name="project_1_deep_float_fields", - fn_name="project_1_deep_float_fields", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["float_values_1"])] - ], - proto_list_key="deep_protos", - ), - dict( - testcase_name="project_2_deep_float_fields", - fn_name="project_2_deep_float_fields", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["child_1", "float_values_2"])] - ], - proto_list_key="deep_protos", - ), - dict( - testcase_name="project_3_deep_float_fields", - fn_name="project_3_deep_float_fields", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["child_1", "child_2", "float_values_3"])] - ], - proto_list_key="deep_protos", - ), - dict( - testcase_name="project_4_deep_float_fields", - 
fn_name="project_4_deep_float_fields", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [ - s2t.path.Path( - ["child_1", "child_2", "child_3", "float_values_4"]) - ] - ], - proto_list_key="deep_protos", - ), - dict( - testcase_name="project_5_deep_float_fields", - fn_name="project_5_deep_float_fields", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [ - s2t.path.Path([ - "child_1", "child_2", "child_3", "child_4", - "float_values_5" - ]) - ] - ], - proto_list_key="deep_protos", - ), - dict( - testcase_name="project_1_flat_float_fields", - fn_name="project_1_flat_float_fields", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [s2t.path.Path(["float_values_1"])] - ], - proto_list_key="flat_protos", - ), - dict( - testcase_name="project_2_flat_float_fields", - fn_name="project_2_flat_float_fields", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - s2t.path.Path(["float_values_1"]), - s2t.path.Path(["float_values_2"]) - ] - ], - proto_list_key="flat_protos", - ), - dict( - testcase_name="project_3_flat_float_fields", - fn_name="project_3_flat_float_fields", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - s2t.path.Path(["float_values_1"]), - s2t.path.Path(["float_values_2"]), - s2t.path.Path(["float_values_3"]) - ] - ], - proto_list_key="flat_protos", - ), - dict( - testcase_name="project_4_flat_float_fields", - fn_name="project_4_flat_float_fields", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - s2t.path.Path(["float_values_1"]), - s2t.path.Path(["float_values_2"]), - s2t.path.Path(["float_values_3"]), - s2t.path.Path(["float_values_4"]) - ] - ], - proto_list_key="flat_protos", - ), - dict( - testcase_name="project_5_flat_float_fields", - fn_name="project_5_flat_float_fields", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - s2t.path.Path(["float_values_1"]), - s2t.path.Path(["float_values_2"]), - s2t.path.Path(["float_values_3"]), - s2t.path.Path(["float_values_4"]), - s2t.path.Path(["float_values_5"]) - ] - ], - proto_list_key="flat_protos", - ), - dict( - testcase_name="project_1_deep_bytes_fields", - fn_name="project_1_deep_bytes_fields", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["bytes_values_1"])] - ], - proto_list_key="deep_protos", - ), - dict( - testcase_name="project_2_deep_bytes_fields", - fn_name="project_2_deep_bytes_fields", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["child_1", "bytes_values_2"])] - ], - proto_list_key="deep_protos", - ), - dict( - testcase_name="project_3_deep_bytes_fields", - fn_name="project_3_deep_bytes_fields", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["child_1", "child_2", "bytes_values_3"])] - ], - proto_list_key="deep_protos", - ), - dict( - testcase_name="project_4_deep_bytes_fields", - fn_name="project_4_deep_bytes_fields", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [ - s2t.path.Path( - ["child_1", "child_2", "child_3", "bytes_values_4"]) - ] - ], - proto_list_key="deep_protos", - ), - dict( - testcase_name="project_5_deep_bytes_fields", - fn_name="project_5_deep_bytes_fields", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [ - s2t.path.Path([ - "child_1", "child_2", "child_3", "child_4", - "bytes_values_5" - ]) - ] - ], - proto_list_key="deep_protos", - ), - dict( - testcase_name="project_1_flat_bytes_fields", - fn_name="project_1_flat_bytes_fields", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [s2t.path.Path(["bytes_values_1"])] - ], - proto_list_key="flat_protos", - ), - dict( - 
testcase_name="project_2_flat_bytes_fields", - fn_name="project_2_flat_bytes_fields", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - s2t.path.Path(["bytes_values_1"]), - s2t.path.Path(["bytes_values_2"]) - ] - ], - proto_list_key="flat_protos", - ), - dict( - testcase_name="project_3_flat_bytes_fields", - fn_name="project_3_flat_bytes_fields", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - s2t.path.Path(["bytes_values_1"]), - s2t.path.Path(["bytes_values_2"]), - s2t.path.Path(["bytes_values_3"]) - ] - ], - proto_list_key="flat_protos", - ), - dict( - testcase_name="project_4_flat_bytes_fields", - fn_name="project_4_flat_bytes_fields", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - s2t.path.Path(["bytes_values_1"]), - s2t.path.Path(["bytes_values_2"]), - s2t.path.Path(["bytes_values_3"]), - s2t.path.Path(["bytes_values_4"]) - ] - ], - proto_list_key="flat_protos", - ), - dict( - testcase_name="project_5_flat_bytes_fields", - fn_name="project_5_flat_bytes_fields", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - s2t.path.Path(["bytes_values_1"]), - s2t.path.Path(["bytes_values_2"]), - s2t.path.Path(["bytes_values_3"]), - s2t.path.Path(["bytes_values_4"]), - s2t.path.Path(["bytes_values_5"]) - ] - ], - proto_list_key="flat_protos", - ), - ]) - # pylint: enable=g-complex-comprehension - def test_project(self, fn_name, fn_args, proto_list_key): - self.run_benchmarks(fn_name, _get_project_fn, fn_args, proto_list_key) + # pylint: disable=g-complex-comprehension + @parameterized.named_parameters( + *[ + dict( + testcase_name="project_1_deep_int_fields", + fn_name="project_1_deep_int_fields", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["int_values_1"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="project_2_deep_int_fields", + fn_name="project_2_deep_int_fields", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["child_1", "int_values_2"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="project_3_deep_int_fields", + fn_name="project_3_deep_int_fields", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["child_1", "child_2", "int_values_3"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="project_4_deep_int_fields", + fn_name="project_4_deep_int_fields", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["child_1", "child_2", "child_3", "int_values_4"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="project_5_deep_int_fields", + fn_name="project_5_deep_int_fields", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [ + s2t.path.Path( + ["child_1", "child_2", "child_3", "child_4", "int_values_5"] + ) + ], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="project_1_flat_int_fields", + fn_name="project_1_flat_int_fields", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [s2t.path.Path(["int_values_1"])], + ], + proto_list_key="flat_protos", + ), + dict( + testcase_name="project_2_flat_int_fields", + fn_name="project_2_flat_int_fields", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [s2t.path.Path(["int_values_1"]), s2t.path.Path(["int_values_2"])], + ], + proto_list_key="flat_protos", + ), + dict( + testcase_name="project_3_flat_int_fields", + fn_name="project_3_flat_int_fields", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [ + s2t.path.Path(["int_values_1"]), + s2t.path.Path(["int_values_2"]), + s2t.path.Path(["int_values_3"]), + ], + ], + proto_list_key="flat_protos", + ), + 
dict( + testcase_name="project_4_flat_int_fields", + fn_name="project_4_flat_int_fields", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [ + s2t.path.Path(["int_values_1"]), + s2t.path.Path(["int_values_2"]), + s2t.path.Path(["int_values_3"]), + s2t.path.Path(["int_values_4"]), + ], + ], + proto_list_key="flat_protos", + ), + dict( + testcase_name="project_5_flat_int_fields", + fn_name="project_5_flat_int_fields", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [ + s2t.path.Path(["int_values_1"]), + s2t.path.Path(["int_values_2"]), + s2t.path.Path(["int_values_3"]), + s2t.path.Path(["int_values_4"]), + s2t.path.Path(["int_values_5"]), + ], + ], + proto_list_key="flat_protos", + ), + dict( + testcase_name="project_1_deep_float_fields", + fn_name="project_1_deep_float_fields", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["float_values_1"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="project_2_deep_float_fields", + fn_name="project_2_deep_float_fields", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["child_1", "float_values_2"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="project_3_deep_float_fields", + fn_name="project_3_deep_float_fields", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["child_1", "child_2", "float_values_3"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="project_4_deep_float_fields", + fn_name="project_4_deep_float_fields", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [ + s2t.path.Path( + ["child_1", "child_2", "child_3", "float_values_4"] + ) + ], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="project_5_deep_float_fields", + fn_name="project_5_deep_float_fields", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [ + s2t.path.Path( + [ + "child_1", + "child_2", + "child_3", + "child_4", + "float_values_5", + ] + ) + ], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="project_1_flat_float_fields", + fn_name="project_1_flat_float_fields", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [s2t.path.Path(["float_values_1"])], + ], + proto_list_key="flat_protos", + ), + dict( + testcase_name="project_2_flat_float_fields", + fn_name="project_2_flat_float_fields", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [ + s2t.path.Path(["float_values_1"]), + s2t.path.Path(["float_values_2"]), + ], + ], + proto_list_key="flat_protos", + ), + dict( + testcase_name="project_3_flat_float_fields", + fn_name="project_3_flat_float_fields", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [ + s2t.path.Path(["float_values_1"]), + s2t.path.Path(["float_values_2"]), + s2t.path.Path(["float_values_3"]), + ], + ], + proto_list_key="flat_protos", + ), + dict( + testcase_name="project_4_flat_float_fields", + fn_name="project_4_flat_float_fields", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [ + s2t.path.Path(["float_values_1"]), + s2t.path.Path(["float_values_2"]), + s2t.path.Path(["float_values_3"]), + s2t.path.Path(["float_values_4"]), + ], + ], + proto_list_key="flat_protos", + ), + dict( + testcase_name="project_5_flat_float_fields", + fn_name="project_5_flat_float_fields", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [ + s2t.path.Path(["float_values_1"]), + s2t.path.Path(["float_values_2"]), + s2t.path.Path(["float_values_3"]), + s2t.path.Path(["float_values_4"]), + s2t.path.Path(["float_values_5"]), + ], + ], + proto_list_key="flat_protos", + ), + dict( + 
testcase_name="project_1_deep_bytes_fields", + fn_name="project_1_deep_bytes_fields", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["bytes_values_1"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="project_2_deep_bytes_fields", + fn_name="project_2_deep_bytes_fields", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["child_1", "bytes_values_2"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="project_3_deep_bytes_fields", + fn_name="project_3_deep_bytes_fields", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["child_1", "child_2", "bytes_values_3"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="project_4_deep_bytes_fields", + fn_name="project_4_deep_bytes_fields", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [ + s2t.path.Path( + ["child_1", "child_2", "child_3", "bytes_values_4"] + ) + ], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="project_5_deep_bytes_fields", + fn_name="project_5_deep_bytes_fields", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [ + s2t.path.Path( + [ + "child_1", + "child_2", + "child_3", + "child_4", + "bytes_values_5", + ] + ) + ], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="project_1_flat_bytes_fields", + fn_name="project_1_flat_bytes_fields", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [s2t.path.Path(["bytes_values_1"])], + ], + proto_list_key="flat_protos", + ), + dict( + testcase_name="project_2_flat_bytes_fields", + fn_name="project_2_flat_bytes_fields", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [ + s2t.path.Path(["bytes_values_1"]), + s2t.path.Path(["bytes_values_2"]), + ], + ], + proto_list_key="flat_protos", + ), + dict( + testcase_name="project_3_flat_bytes_fields", + fn_name="project_3_flat_bytes_fields", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [ + s2t.path.Path(["bytes_values_1"]), + s2t.path.Path(["bytes_values_2"]), + s2t.path.Path(["bytes_values_3"]), + ], + ], + proto_list_key="flat_protos", + ), + dict( + testcase_name="project_4_flat_bytes_fields", + fn_name="project_4_flat_bytes_fields", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [ + s2t.path.Path(["bytes_values_1"]), + s2t.path.Path(["bytes_values_2"]), + s2t.path.Path(["bytes_values_3"]), + s2t.path.Path(["bytes_values_4"]), + ], + ], + proto_list_key="flat_protos", + ), + dict( + testcase_name="project_5_flat_bytes_fields", + fn_name="project_5_flat_bytes_fields", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [ + s2t.path.Path(["bytes_values_1"]), + s2t.path.Path(["bytes_values_2"]), + s2t.path.Path(["bytes_values_3"]), + s2t.path.Path(["bytes_values_4"]), + s2t.path.Path(["bytes_values_5"]), + ], + ], + proto_list_key="flat_protos", + ), + ] + ) + # pylint: enable=g-complex-comprehension + def test_project(self, fn_name, fn_args, proto_list_key): + self.run_benchmarks(fn_name, _get_project_fn, fn_args, proto_list_key) class PromoteBenchmarks(struct2tensor_benchmark_util.ProtoDataBenchmarks): - """Benchmarks for promoting fields.""" + """Benchmarks for promoting fields.""" - @parameterized.named_parameters(*[ - dict( - testcase_name="promote_1", - fn_name="promote_1", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - s2t.path.Path(["child_1", "int_values_2"]) - ], - proto_list_key="deep_protos"), - dict( - testcase_name="promote_2", - fn_name="promote_2", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - s2t.path.Path(["child_1", "child_2", "int_values_3"]) - ], - 
proto_list_key="deep_protos"), - dict( - testcase_name="promote_3", - fn_name="promote_3", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - s2t.path.Path(["child_1", "child_2", "child_3", "int_values_4"]) - ], - proto_list_key="deep_protos"), - dict( - testcase_name="promote_4", - fn_name="promote_4", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - s2t.path.Path( - ["child_1", "child_2", "child_3", "child_4", "int_values_5"]) - ], - proto_list_key="deep_protos"), - ]) - def test_promote(self, fn_name, fn_args, proto_list_key): - self.run_benchmarks(fn_name, _get_promote_fn, fn_args, proto_list_key) + @parameterized.named_parameters( + *[ + dict( + testcase_name="promote_1", + fn_name="promote_1", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + s2t.path.Path(["child_1", "int_values_2"]), + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="promote_2", + fn_name="promote_2", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + s2t.path.Path(["child_1", "child_2", "int_values_3"]), + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="promote_3", + fn_name="promote_3", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + s2t.path.Path(["child_1", "child_2", "child_3", "int_values_4"]), + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="promote_4", + fn_name="promote_4", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + s2t.path.Path( + ["child_1", "child_2", "child_3", "child_4", "int_values_5"] + ), + ], + proto_list_key="deep_protos", + ), + ] + ) + def test_promote(self, fn_name, fn_args, proto_list_key): + self.run_benchmarks(fn_name, _get_promote_fn, fn_args, proto_list_key) class BroadcastBenchmarks(struct2tensor_benchmark_util.ProtoDataBenchmarks): - """Benchmarks for broadcasting fields.""" + """Benchmarks for broadcasting fields.""" - @parameterized.named_parameters(*[ - dict( - testcase_name="broadcast_2", - fn_name="broadcast_2", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - s2t.path.Path(["int_values_1"]), "child_1" - ], - proto_list_key="deep_protos"), - dict( - testcase_name="broadcast_3", - fn_name="broadcast_3", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - s2t.path.Path(["child_1", "int_values_2"]), "child_2" - ], - proto_list_key="deep_protos"), - dict( - testcase_name="broadcast_4", - fn_name="broadcast_4", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - s2t.path.Path(["child_1", "child_2", "int_values_3"]), "child_3" - ], - proto_list_key="deep_protos"), - dict( - testcase_name="broadcast_5", - fn_name="broadcast_5", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - s2t.path.Path(["child_1", "child_2", "child_3", "int_values_4"]), - "child_4" - ], - proto_list_key="deep_protos"), - ]) - def test_broadcast(self, fn_name, fn_args, proto_list_key): - self.run_benchmarks(fn_name, _get_broadcast_fn, fn_args, proto_list_key) + @parameterized.named_parameters( + *[ + dict( + testcase_name="broadcast_2", + fn_name="broadcast_2", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + s2t.path.Path(["int_values_1"]), + "child_1", + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="broadcast_3", + fn_name="broadcast_3", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + s2t.path.Path(["child_1", "int_values_2"]), + "child_2", + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="broadcast_4", + fn_name="broadcast_4", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + s2t.path.Path(["child_1", "child_2", "int_values_3"]), + "child_3", + ], + proto_list_key="deep_protos", + ), + dict( + 
testcase_name="broadcast_5", + fn_name="broadcast_5", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + s2t.path.Path(["child_1", "child_2", "child_3", "int_values_4"]), + "child_4", + ], + proto_list_key="deep_protos", + ), + ] + ) + def test_broadcast(self, fn_name, fn_args, proto_list_key): + self.run_benchmarks(fn_name, _get_broadcast_fn, fn_args, proto_list_key) class RerootBenchmarks(struct2tensor_benchmark_util.ProtoDataBenchmarks): - """Benchmarks for rerooting fields.""" + """Benchmarks for rerooting fields.""" - @parameterized.named_parameters(*[ - dict( - testcase_name="reroot_1", - fn_name="reroot_1", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - s2t.path.Path(["child_1"]), [s2t.path.Path(["int_values_2"])] - ], - proto_list_key="deep_protos"), - dict( - testcase_name="reroot_2", - fn_name="reroot_2", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - s2t.path.Path(["child_1", "child_2"]), - [s2t.path.Path(["int_values_3"])] - ], - proto_list_key="deep_protos"), - dict( - testcase_name="reroot_3", - fn_name="reroot_3", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - s2t.path.Path(["child_1", "child_2", "child_3"]), - [s2t.path.Path(["int_values_4"])] - ], - proto_list_key="deep_protos"), - dict( - testcase_name="reroot_4", - fn_name="reroot_4", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - s2t.path.Path(["child_1", "child_2", "child_3", "child_4"]), - [s2t.path.Path(["int_values_5"])] - ], - proto_list_key="deep_protos"), - ]) - def test_reroot(self, fn_name, fn_args, proto_list_key): - self.run_benchmarks(fn_name, _get_reroot_fn, fn_args, proto_list_key) + @parameterized.named_parameters( + *[ + dict( + testcase_name="reroot_1", + fn_name="reroot_1", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + s2t.path.Path(["child_1"]), + [s2t.path.Path(["int_values_2"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="reroot_2", + fn_name="reroot_2", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + s2t.path.Path(["child_1", "child_2"]), + [s2t.path.Path(["int_values_3"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="reroot_3", + fn_name="reroot_3", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + s2t.path.Path(["child_1", "child_2", "child_3"]), + [s2t.path.Path(["int_values_4"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="reroot_4", + fn_name="reroot_4", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + s2t.path.Path(["child_1", "child_2", "child_3", "child_4"]), + [s2t.path.Path(["int_values_5"])], + ], + proto_list_key="deep_protos", + ), + ] + ) + def test_reroot(self, fn_name, fn_args, proto_list_key): + self.run_benchmarks(fn_name, _get_reroot_fn, fn_args, proto_list_key) -class PrensorToTensorBenchmarks( - struct2tensor_benchmark_util.ProtoDataBenchmarks): - """Benchmarks for converting prensor to tensors.""" +class PrensorToTensorBenchmarks(struct2tensor_benchmark_util.ProtoDataBenchmarks): + """Benchmarks for converting prensor to tensors.""" - # pylint: disable=g-complex-comprehension - @parameterized.named_parameters(*[ - dict( - testcase_name=f"flat_to_dense_{n}_features", - fn_name=f"flat_to_dense_{n}_features", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - s2t.path.Path([f"int_values_{i + 1}"]) - for i in range(n) - ] - ], - proto_list_key="flat_protos") - for n in [1, 2, 3, 4, 5] - ] + [ - dict( - testcase_name="flat_to_dense_1_feature_100_values", - fn_name="flat_to_dense_1_feature_100_values", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - 
s2t.path.Path(["int_values_1"]) - ] - ], - proto_list_key="flat_protos_1_feature_100_values"), - dict( - testcase_name="flat_to_dense_100_features", - fn_name="flat_to_dense_100_features", - fn_args=[ - benchmark_pb2.FlatProto100.DESCRIPTOR, - [ - s2t.path.Path([f"int_values_{i}"]) - for i in range(1, 101) - ] - ], - proto_list_key="flat_protos_100_features"), - dict( - testcase_name="flat_to_dense_100_features_100_values", - fn_name="flat_to_dense_100_features_100_values", - fn_args=[ - benchmark_pb2.FlatProto100.DESCRIPTOR, - [ - s2t.path.Path([f"int_values_{i}"]) - for i in range(1, 101) - ] - ], - proto_list_key="flat_protos_100_features_100_values"), - dict( - testcase_name="deep_to_dense_1", - fn_name="deep_to_dense_1", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["int_values_1"])] - ], - proto_list_key="deep_protos"), - dict( - testcase_name="deep_to_dense_2", - fn_name="deep_to_dense_2", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["child_1", "int_values_2"])] - ], - proto_list_key="deep_protos"), - dict( - testcase_name="deep_to_dense_3", - fn_name="deep_to_dense_3", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["child_1", "child_2", "int_values_3"])] - ], - proto_list_key="deep_protos"), - dict( - testcase_name="deep_to_dense_4", - fn_name="deep_to_dense_4", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [ - s2t.path.Path( - ["child_1", "child_2", "child_3", "int_values_4"]) - ] - ], - proto_list_key="deep_protos"), - dict( - testcase_name="deep_to_dense_5", - fn_name="deep_to_dense_5", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [ - s2t.path.Path([ - "child_1", "child_2", "child_3", "child_4", "int_values_5" - ]) - ] - ], - proto_list_key="deep_protos"), - ]) - # pylint: enable=g-complex-comprehension - def test_to_dense(self, fn_name, fn_args, proto_list_key): - """This benchmark converts prensors to dense tensors.""" - self.run_benchmarks(fn_name, _get_prensor_to_dense_tensor_fn, fn_args, - proto_list_key) + # pylint: disable=g-complex-comprehension + @parameterized.named_parameters( + *[ + dict( + testcase_name=f"flat_to_dense_{n}_features", + fn_name=f"flat_to_dense_{n}_features", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [s2t.path.Path([f"int_values_{i + 1}"]) for i in range(n)], + ], + proto_list_key="flat_protos", + ) + for n in [1, 2, 3, 4, 5] + ] + + [ + dict( + testcase_name="flat_to_dense_1_feature_100_values", + fn_name="flat_to_dense_1_feature_100_values", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [s2t.path.Path(["int_values_1"])], + ], + proto_list_key="flat_protos_1_feature_100_values", + ), + dict( + testcase_name="flat_to_dense_100_features", + fn_name="flat_to_dense_100_features", + fn_args=[ + benchmark_pb2.FlatProto100.DESCRIPTOR, + [s2t.path.Path([f"int_values_{i}"]) for i in range(1, 101)], + ], + proto_list_key="flat_protos_100_features", + ), + dict( + testcase_name="flat_to_dense_100_features_100_values", + fn_name="flat_to_dense_100_features_100_values", + fn_args=[ + benchmark_pb2.FlatProto100.DESCRIPTOR, + [s2t.path.Path([f"int_values_{i}"]) for i in range(1, 101)], + ], + proto_list_key="flat_protos_100_features_100_values", + ), + dict( + testcase_name="deep_to_dense_1", + fn_name="deep_to_dense_1", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["int_values_1"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="deep_to_dense_2", + fn_name="deep_to_dense_2", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + 
[s2t.path.Path(["child_1", "int_values_2"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="deep_to_dense_3", + fn_name="deep_to_dense_3", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["child_1", "child_2", "int_values_3"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="deep_to_dense_4", + fn_name="deep_to_dense_4", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["child_1", "child_2", "child_3", "int_values_4"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="deep_to_dense_5", + fn_name="deep_to_dense_5", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [ + s2t.path.Path( + ["child_1", "child_2", "child_3", "child_4", "int_values_5"] + ) + ], + ], + proto_list_key="deep_protos", + ), + ] + ) + # pylint: enable=g-complex-comprehension + def test_to_dense(self, fn_name, fn_args, proto_list_key): + """This benchmark converts prensors to dense tensors.""" + self.run_benchmarks( + fn_name, _get_prensor_to_dense_tensor_fn, fn_args, proto_list_key + ) - # pylint: disable=g-complex-comprehension - @parameterized.named_parameters(*[ - dict( - testcase_name=f"flat_to_ragged_{n}_features", - fn_name=f"flat_to_ragged_{n}_features", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - s2t.path.Path([f"int_values_{i + 1}"]) - for i in range(n) - ] - ], - proto_list_key="flat_protos") - for n in [1, 2, 3, 4, 5] - ] + [ - dict( - testcase_name="flat_to_ragged_1_feature_100_values", - fn_name="flat_to_ragged_1_feature_100_values", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - s2t.path.Path(["int_values_1"]) - ] - ], - proto_list_key="flat_protos_1_feature_100_values"), - dict( - testcase_name="flat_to_ragged_100_features", - fn_name="flat_to_ragged_100_features", - fn_args=[ - benchmark_pb2.FlatProto100.DESCRIPTOR, - [ - s2t.path.Path([f"int_values_{i}"]) - for i in range(1, 101) - ] - ], - proto_list_key="flat_protos_100_features"), - dict( - testcase_name="flat_to_ragged_100_features_100_values", - fn_name="flat_to_ragged_100_features_100_values", - fn_args=[ - benchmark_pb2.FlatProto100.DESCRIPTOR, - [ - s2t.path.Path([f"int_values_{i}"]) - for i in range(1, 101) - ] - ], - proto_list_key="flat_protos_100_features_100_values"), - dict( - testcase_name="deep_to_ragged_1", - fn_name="deep_to_ragged_1", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["int_values_1"])] - ], - proto_list_key="deep_protos"), - dict( - testcase_name="deep_to_ragged_2", - fn_name="deep_to_ragged_2", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["child_1", "int_values_2"])] - ], - proto_list_key="deep_protos"), - dict( - testcase_name="deep_to_ragged_3", - fn_name="deep_to_ragged_3", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["child_1", "child_2", "int_values_3"])] - ], - proto_list_key="deep_protos"), - dict( - testcase_name="deep_to_ragged_4", - fn_name="deep_to_ragged_4", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [ - s2t.path.Path( - ["child_1", "child_2", "child_3", "int_values_4"]) - ] - ], - proto_list_key="deep_protos"), - dict( - testcase_name="deep_to_ragged_5", - fn_name="deep_to_ragged_5", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [ - s2t.path.Path([ - "child_1", "child_2", "child_3", "child_4", "int_values_5" - ]) - ] - ], - proto_list_key="deep_protos"), - ]) - # pylint: enable=g-complex-comprehension - def test_to_ragged(self, fn_name, fn_args, proto_list_key): - """This benchmark converts prensors to ragged 
tensors.""" - self.run_benchmarks(fn_name, _get_prensor_to_ragged_tensor_fn, fn_args, - proto_list_key) + # pylint: disable=g-complex-comprehension + @parameterized.named_parameters( + *[ + dict( + testcase_name=f"flat_to_ragged_{n}_features", + fn_name=f"flat_to_ragged_{n}_features", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [s2t.path.Path([f"int_values_{i + 1}"]) for i in range(n)], + ], + proto_list_key="flat_protos", + ) + for n in [1, 2, 3, 4, 5] + ] + + [ + dict( + testcase_name="flat_to_ragged_1_feature_100_values", + fn_name="flat_to_ragged_1_feature_100_values", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [s2t.path.Path(["int_values_1"])], + ], + proto_list_key="flat_protos_1_feature_100_values", + ), + dict( + testcase_name="flat_to_ragged_100_features", + fn_name="flat_to_ragged_100_features", + fn_args=[ + benchmark_pb2.FlatProto100.DESCRIPTOR, + [s2t.path.Path([f"int_values_{i}"]) for i in range(1, 101)], + ], + proto_list_key="flat_protos_100_features", + ), + dict( + testcase_name="flat_to_ragged_100_features_100_values", + fn_name="flat_to_ragged_100_features_100_values", + fn_args=[ + benchmark_pb2.FlatProto100.DESCRIPTOR, + [s2t.path.Path([f"int_values_{i}"]) for i in range(1, 101)], + ], + proto_list_key="flat_protos_100_features_100_values", + ), + dict( + testcase_name="deep_to_ragged_1", + fn_name="deep_to_ragged_1", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["int_values_1"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="deep_to_ragged_2", + fn_name="deep_to_ragged_2", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["child_1", "int_values_2"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="deep_to_ragged_3", + fn_name="deep_to_ragged_3", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["child_1", "child_2", "int_values_3"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="deep_to_ragged_4", + fn_name="deep_to_ragged_4", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["child_1", "child_2", "child_3", "int_values_4"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="deep_to_ragged_5", + fn_name="deep_to_ragged_5", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [ + s2t.path.Path( + ["child_1", "child_2", "child_3", "child_4", "int_values_5"] + ) + ], + ], + proto_list_key="deep_protos", + ), + ] + ) + # pylint: enable=g-complex-comprehension + def test_to_ragged(self, fn_name, fn_args, proto_list_key): + """This benchmark converts prensors to ragged tensors.""" + self.run_benchmarks( + fn_name, _get_prensor_to_ragged_tensor_fn, fn_args, proto_list_key + ) - # pylint: disable=g-complex-comprehension - @parameterized.named_parameters(*[ - dict( - testcase_name=f"flat_to_sparse_{n}_features", - fn_name=f"flat_to_sparse_{n}_features", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - s2t.path.Path([f"int_values_{i + 1}"]) - for i in range(n) - ] - ], - proto_list_key="flat_protos") - for n in [1, 2, 3, 4, 5] - ] + [ - dict( - testcase_name="flat_to_sparse_1_feature_100_values", - fn_name="flat_to_sparse_1_feature_100_values", - fn_args=[ - benchmark_pb2.FlatProto.DESCRIPTOR, - [ - s2t.path.Path(["int_values_1"]) - ] - ], - proto_list_key="flat_protos_1_feature_100_values"), - dict( - testcase_name="flat_to_sparse_100_features", - fn_name="flat_to_sparse_100_features", - fn_args=[ - benchmark_pb2.FlatProto100.DESCRIPTOR, - [ - s2t.path.Path([f"int_values_{i}"]) - for i in range(1, 
101) - ] - ], - proto_list_key="flat_protos_100_features"), - dict( - testcase_name="flat_to_sparse_100_features_100_values", - fn_name="flat_to_sparse_100_features_100_values", - fn_args=[ - benchmark_pb2.FlatProto100.DESCRIPTOR, - [ - s2t.path.Path([f"int_values_{i}"]) - for i in range(1, 101) - ] - ], - proto_list_key="flat_protos_100_features_100_values"), - dict( - testcase_name="deep_to_sparse_1", - fn_name="deep_to_sparse_1", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["int_values_1"])] - ], - proto_list_key="deep_protos"), - dict( - testcase_name="deep_to_sparse_2", - fn_name="deep_to_sparse_2", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["child_1", "int_values_2"])] - ], - proto_list_key="deep_protos"), - dict( - testcase_name="deep_to_sparse_3", - fn_name="deep_to_sparse_3", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [s2t.path.Path(["child_1", "child_2", "int_values_3"])] - ], - proto_list_key="deep_protos"), - dict( - testcase_name="deep_to_sparse_4", - fn_name="deep_to_sparse_4", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [ - s2t.path.Path( - ["child_1", "child_2", "child_3", "int_values_4"]) - ] - ], - proto_list_key="deep_protos"), - dict( - testcase_name="deep_to_sparse_5", - fn_name="deep_to_sparse_5", - fn_args=[ - benchmark_pb2.DeepProto.DESCRIPTOR, - [ - s2t.path.Path([ - "child_1", "child_2", "child_3", "child_4", "int_values_5" - ]) - ] - ], - proto_list_key="deep_protos"), - ]) - # pylint: enable=g-complex-comprehension - def test_to_sparse(self, fn_name, fn_args, proto_list_key): - """This benchmark converts prensors to sparse tensors.""" - self.run_benchmarks(fn_name, _get_prensor_to_sparse_tensor_fn, fn_args, - proto_list_key) + # pylint: disable=g-complex-comprehension + @parameterized.named_parameters( + *[ + dict( + testcase_name=f"flat_to_sparse_{n}_features", + fn_name=f"flat_to_sparse_{n}_features", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [s2t.path.Path([f"int_values_{i + 1}"]) for i in range(n)], + ], + proto_list_key="flat_protos", + ) + for n in [1, 2, 3, 4, 5] + ] + + [ + dict( + testcase_name="flat_to_sparse_1_feature_100_values", + fn_name="flat_to_sparse_1_feature_100_values", + fn_args=[ + benchmark_pb2.FlatProto.DESCRIPTOR, + [s2t.path.Path(["int_values_1"])], + ], + proto_list_key="flat_protos_1_feature_100_values", + ), + dict( + testcase_name="flat_to_sparse_100_features", + fn_name="flat_to_sparse_100_features", + fn_args=[ + benchmark_pb2.FlatProto100.DESCRIPTOR, + [s2t.path.Path([f"int_values_{i}"]) for i in range(1, 101)], + ], + proto_list_key="flat_protos_100_features", + ), + dict( + testcase_name="flat_to_sparse_100_features_100_values", + fn_name="flat_to_sparse_100_features_100_values", + fn_args=[ + benchmark_pb2.FlatProto100.DESCRIPTOR, + [s2t.path.Path([f"int_values_{i}"]) for i in range(1, 101)], + ], + proto_list_key="flat_protos_100_features_100_values", + ), + dict( + testcase_name="deep_to_sparse_1", + fn_name="deep_to_sparse_1", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["int_values_1"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="deep_to_sparse_2", + fn_name="deep_to_sparse_2", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["child_1", "int_values_2"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="deep_to_sparse_3", + fn_name="deep_to_sparse_3", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["child_1", "child_2", "int_values_3"])], + ], + 
proto_list_key="deep_protos", + ), + dict( + testcase_name="deep_to_sparse_4", + fn_name="deep_to_sparse_4", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [s2t.path.Path(["child_1", "child_2", "child_3", "int_values_4"])], + ], + proto_list_key="deep_protos", + ), + dict( + testcase_name="deep_to_sparse_5", + fn_name="deep_to_sparse_5", + fn_args=[ + benchmark_pb2.DeepProto.DESCRIPTOR, + [ + s2t.path.Path( + ["child_1", "child_2", "child_3", "child_4", "int_values_5"] + ) + ], + ], + proto_list_key="deep_protos", + ), + ] + ) + # pylint: enable=g-complex-comprehension + def test_to_sparse(self, fn_name, fn_args, proto_list_key): + """This benchmark converts prensors to sparse tensors.""" + self.run_benchmarks( + fn_name, _get_prensor_to_sparse_tensor_fn, fn_args, proto_list_key + ) class TfExampleBenchmarks(struct2tensor_benchmark_util.ProtoDataBenchmarks): - """Benchmarks for converting tf.example to tensors.""" + """Benchmarks for converting tf.example to tensors.""" - # pylint: disable=g-complex-comprehension - @parameterized.named_parameters(*[ - dict( - testcase_name=f"to_fixed_len_feature_{n}_features_1_value", - fn_name=f"tf_example_to_fixed_len_feature_{n}", - fn_args=[{ - "int_values_1": tf.io.FixedLenFeature(shape=[1], dtype=tf.int64) - }], - proto_list_key="tf_examples") for n in [1, 2, 3, 4, 5] - ] + [ - dict( - testcase_name="to_fixed_len_feature_1_feature_100_value", - fn_name="tf_example_to_fixed_len_feature_1_feature_100_value", - fn_args=[{ - "int_values_1": tf.io.FixedLenFeature( - shape=[100], dtype=tf.int64) - }], - proto_list_key="tf_examples_1_feature_100_values"), - dict( - testcase_name="to_fixed_len_feature_100_features_1_value", - fn_name="tf_example_to_fixed_len_feature_100_features_1_value", - fn_args=[{ - f"int_values_{i}": tf.io.FixedLenFeature( - shape=[1], dtype=tf.int64) for i in range(1, 101) - }], - proto_list_key="tf_examples_100_features"), - dict( - testcase_name="to_fixed_len_feature_100_features_100_values", - fn_name="tf_example_to_fixed_len_feature_100_features_100_values", - fn_args=[{ - f"int_values_{i}": tf.io.FixedLenFeature( - shape=[100], dtype=tf.int64) for i in range(1, 101) - }], - proto_list_key="tf_examples_100_features_100_values") - ]) - # pylint: enable=g-complex-comprehension - def test_parse_tf_example(self, fn_name, fn_args, proto_list_key): - self.run_benchmarks(fn_name, _get_parse_tf_example_fn, fn_args, - proto_list_key) + # pylint: disable=g-complex-comprehension + @parameterized.named_parameters( + *[ + dict( + testcase_name=f"to_fixed_len_feature_{n}_features_1_value", + fn_name=f"tf_example_to_fixed_len_feature_{n}", + fn_args=[ + {"int_values_1": tf.io.FixedLenFeature(shape=[1], dtype=tf.int64)} + ], + proto_list_key="tf_examples", + ) + for n in [1, 2, 3, 4, 5] + ] + + [ + dict( + testcase_name="to_fixed_len_feature_1_feature_100_value", + fn_name="tf_example_to_fixed_len_feature_1_feature_100_value", + fn_args=[ + {"int_values_1": tf.io.FixedLenFeature(shape=[100], dtype=tf.int64)} + ], + proto_list_key="tf_examples_1_feature_100_values", + ), + dict( + testcase_name="to_fixed_len_feature_100_features_1_value", + fn_name="tf_example_to_fixed_len_feature_100_features_1_value", + fn_args=[ + { + f"int_values_{i}": tf.io.FixedLenFeature( + shape=[1], dtype=tf.int64 + ) + for i in range(1, 101) + } + ], + proto_list_key="tf_examples_100_features", + ), + dict( + testcase_name="to_fixed_len_feature_100_features_100_values", + fn_name="tf_example_to_fixed_len_feature_100_features_100_values", + fn_args=[ + { + 
f"int_values_{i}": tf.io.FixedLenFeature( + shape=[100], dtype=tf.int64 + ) + for i in range(1, 101) + } + ], + proto_list_key="tf_examples_100_features_100_values", + ), + ] + ) + # pylint: enable=g-complex-comprehension + def test_parse_tf_example(self, fn_name, fn_args, proto_list_key): + self.run_benchmarks(fn_name, _get_parse_tf_example_fn, fn_args, proto_list_key) - @parameterized.named_parameters(*[ - dict( - testcase_name=f"to_var_len_feature_{n}_features_1_value", - fn_name=f"tf_example_to_var_len_feature_{n}", - fn_args=[{ - "int_values_1": tf.io.VarLenFeature(dtype=tf.int64) - }], - proto_list_key="tf_examples") for n in [1, 2, 3, 4, 5] - ] + [ - dict( - testcase_name="to_var_len_feature_1_feature_100_value", - fn_name="tf_example_to_var_len_feature_1_feature_100_value", - fn_args=[{ - "int_values_1": tf.io.VarLenFeature(dtype=tf.int64) - }], - proto_list_key="tf_examples_1_feature_100_values"), - dict( - testcase_name="to_var_len_feature_100_features_1_value", - fn_name="tf_example_to_var_len_feature_100_features_1_value", - fn_args=[{ - f"int_values_{i}": tf.io.VarLenFeature(dtype=tf.int64) - for i in range(1, 101) - }], - proto_list_key="tf_examples_100_features"), - dict( - testcase_name="to_var_len_feature_100_features_100_values", - fn_name="tf_example_to_var_len_feature_100_features_100_values", - fn_args=[{ - f"int_values_{i}": tf.io.VarLenFeature(dtype=tf.int64) - for i in range(1, 101) - }], - proto_list_key="tf_examples_100_features_100_values") - ]) - def test_parse_tf_example_to_sparse(self, fn_name, fn_args, proto_list_key): - self.run_benchmarks(fn_name, _get_parse_tf_example_to_sparse_fn, fn_args, - proto_list_key) + @parameterized.named_parameters( + *[ + dict( + testcase_name=f"to_var_len_feature_{n}_features_1_value", + fn_name=f"tf_example_to_var_len_feature_{n}", + fn_args=[{"int_values_1": tf.io.VarLenFeature(dtype=tf.int64)}], + proto_list_key="tf_examples", + ) + for n in [1, 2, 3, 4, 5] + ] + + [ + dict( + testcase_name="to_var_len_feature_1_feature_100_value", + fn_name="tf_example_to_var_len_feature_1_feature_100_value", + fn_args=[{"int_values_1": tf.io.VarLenFeature(dtype=tf.int64)}], + proto_list_key="tf_examples_1_feature_100_values", + ), + dict( + testcase_name="to_var_len_feature_100_features_1_value", + fn_name="tf_example_to_var_len_feature_100_features_1_value", + fn_args=[ + { + f"int_values_{i}": tf.io.VarLenFeature(dtype=tf.int64) + for i in range(1, 101) + } + ], + proto_list_key="tf_examples_100_features", + ), + dict( + testcase_name="to_var_len_feature_100_features_100_values", + fn_name="tf_example_to_var_len_feature_100_features_100_values", + fn_args=[ + { + f"int_values_{i}": tf.io.VarLenFeature(dtype=tf.int64) + for i in range(1, 101) + } + ], + proto_list_key="tf_examples_100_features_100_values", + ), + ] + ) + def test_parse_tf_example_to_sparse(self, fn_name, fn_args, proto_list_key): + self.run_benchmarks( + fn_name, _get_parse_tf_example_to_sparse_fn, fn_args, proto_list_key + ) - @parameterized.named_parameters(*[ - dict( - testcase_name=f"to_ragged_feature_{n}_features_1_value", - fn_name=f"tf_example_to_ragged_feature_{n}", - fn_args=[{ - "int_values_1": - tf.io.RaggedFeature(value_key="int_values_1", dtype=tf.int64) - }], - proto_list_key="tf_examples") for n in [1, 2, 3, 4, 5] - ] + [ - dict( - testcase_name="to_ragged_feature_1_feature_100_value", - fn_name="tf_example_to_ragged_feature_1_feature_100_value", - fn_args=[{ - "int_values_1": - tf.io.RaggedFeature(value_key="int_values_1", dtype=tf.int64) - }], - 
proto_list_key="tf_examples_1_feature_100_values"), - dict( - testcase_name="to_ragged_feature_100_features_1_value", - fn_name="tf_example_to_ragged_feature_100_features_1_value", - fn_args=[{ - f"int_values_{i}": tf.io.RaggedFeature( - value_key=f"int_values_{i}", dtype=tf.int64) - for i in range(1, 101) - }], - proto_list_key="tf_examples_100_features"), - dict( - testcase_name="to_ragged_feature_100_features_100_values", - fn_name="tf_example_to_ragged_feature_100_features_100_values", - fn_args=[{ - f"int_values_{i}": tf.io.RaggedFeature( - value_key=f"int_values_{i}", dtype=tf.int64) - for i in range(1, 101) - }], - proto_list_key="tf_examples_100_features_100_values") - ]) - def test_parse_tf_example_to_ragged(self, fn_name, fn_args, proto_list_key): - self.run_benchmarks(fn_name, _get_parse_tf_example_fn, fn_args, - proto_list_key) + @parameterized.named_parameters( + *[ + dict( + testcase_name=f"to_ragged_feature_{n}_features_1_value", + fn_name=f"tf_example_to_ragged_feature_{n}", + fn_args=[ + { + "int_values_1": tf.io.RaggedFeature( + value_key="int_values_1", dtype=tf.int64 + ) + } + ], + proto_list_key="tf_examples", + ) + for n in [1, 2, 3, 4, 5] + ] + + [ + dict( + testcase_name="to_ragged_feature_1_feature_100_value", + fn_name="tf_example_to_ragged_feature_1_feature_100_value", + fn_args=[ + { + "int_values_1": tf.io.RaggedFeature( + value_key="int_values_1", dtype=tf.int64 + ) + } + ], + proto_list_key="tf_examples_1_feature_100_values", + ), + dict( + testcase_name="to_ragged_feature_100_features_1_value", + fn_name="tf_example_to_ragged_feature_100_features_1_value", + fn_args=[ + { + f"int_values_{i}": tf.io.RaggedFeature( + value_key=f"int_values_{i}", dtype=tf.int64 + ) + for i in range(1, 101) + } + ], + proto_list_key="tf_examples_100_features", + ), + dict( + testcase_name="to_ragged_feature_100_features_100_values", + fn_name="tf_example_to_ragged_feature_100_features_100_values", + fn_args=[ + { + f"int_values_{i}": tf.io.RaggedFeature( + value_key=f"int_values_{i}", dtype=tf.int64 + ) + for i in range(1, 101) + } + ], + proto_list_key="tf_examples_100_features_100_values", + ), + ] + ) + def test_parse_tf_example_to_ragged(self, fn_name, fn_args, proto_list_key): + self.run_benchmarks(fn_name, _get_parse_tf_example_fn, fn_args, proto_list_key) def _get_project_fn(session, proto_descriptor, path): - """Returns a callable that projects the path, and creates ragged tensors.""" - protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,)) - expr = s2t.expression_impl.proto.create_expression_from_proto( - protos, proto_descriptor) - expr = expr.project(path) - [prensor] = s2t.calculate.calculate_prensors( - [expr], options=s2t.calculate_options.get_options_with_minimal_checks()) - rt = prensor.get_ragged_tensors() - with tf.control_dependencies(rt.values()): - x = tf.constant(1) - return session.make_callable(x, feed_list=[protos]) + """Returns a callable that projects the path, and creates ragged tensors.""" + protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,)) + expr = s2t.expression_impl.proto.create_expression_from_proto( + protos, proto_descriptor + ) + expr = expr.project(path) + [prensor] = s2t.calculate.calculate_prensors( + [expr], options=s2t.calculate_options.get_options_with_minimal_checks() + ) + rt = prensor.get_ragged_tensors() + with tf.control_dependencies(rt.values()): + x = tf.constant(1) + return session.make_callable(x, feed_list=[protos]) def _get_promote_fn(session, proto_descriptor, path_to_promote): - """Returns a 
 def _get_project_fn(session, proto_descriptor, path):
-  """Returns a callable that projects the path, and creates ragged tensors."""
-  protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))
-  expr = s2t.expression_impl.proto.create_expression_from_proto(
-      protos, proto_descriptor)
-  expr = expr.project(path)
-  [prensor] = s2t.calculate.calculate_prensors(
-      [expr], options=s2t.calculate_options.get_options_with_minimal_checks())
-  rt = prensor.get_ragged_tensors()
-  with tf.control_dependencies(rt.values()):
-    x = tf.constant(1)
-  return session.make_callable(x, feed_list=[protos])
+    """Returns a callable that projects the path and creates ragged tensors."""
+    protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))
+    expr = s2t.expression_impl.proto.create_expression_from_proto(
+        protos, proto_descriptor
+    )
+    expr = expr.project(path)
+    [prensor] = s2t.calculate.calculate_prensors(
+        [expr], options=s2t.calculate_options.get_options_with_minimal_checks()
+    )
+    rt = prensor.get_ragged_tensors()
+    with tf.control_dependencies(rt.values()):
+        x = tf.constant(1)
+    return session.make_callable(x, feed_list=[protos])


 def _get_promote_fn(session, proto_descriptor, path_to_promote):
-  """Returns a callable that promotes a field, and creates ragged tensors."""
-  protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))
-  expr = s2t.expression_impl.proto.create_expression_from_proto(
-      protos, proto_descriptor).promote(path_to_promote, "new_child").project(
-          [path_to_promote.prefix(-2).concat(s2t.path.Path(["new_child"]))])
-  [prensor] = s2t.calculate.calculate_prensors(
-      [expr], options=s2t.calculate_options.get_options_with_minimal_checks())
-  rt = prensor.get_ragged_tensors()
-  with tf.control_dependencies(rt.values()):
-    x = tf.constant(1)
-  return session.make_callable(x, feed_list=[protos])
+    """Returns a callable that promotes a field and creates ragged tensors."""
+    protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))
+    expr = (
+        s2t.expression_impl.proto.create_expression_from_proto(protos, proto_descriptor)
+        .promote(path_to_promote, "new_child")
+        .project([path_to_promote.prefix(-2).concat(s2t.path.Path(["new_child"]))])
+    )
+    [prensor] = s2t.calculate.calculate_prensors(
+        [expr], options=s2t.calculate_options.get_options_with_minimal_checks()
+    )
+    rt = prensor.get_ragged_tensors()
+    with tf.control_dependencies(rt.values()):
+        x = tf.constant(1)
+    return session.make_callable(x, feed_list=[protos])


 def _get_broadcast_fn(session, proto_descriptor, path_to_broadcast, sibling):
-  """Returns a callable that broadcasts a field, and creates ragged tensors."""
-  protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))
-  expr = s2t.expression_impl.proto.create_expression_from_proto(
-      protos,
-      proto_descriptor).broadcast(path_to_broadcast, sibling,
-                                  "new_child").project([
-                                      path_to_broadcast.get_parent().concat(
-                                          s2t.path.Path([sibling, "new_child"]))
-                                  ])
-  [prensor] = s2t.calculate.calculate_prensors(
-      [expr], options=s2t.calculate_options.get_options_with_minimal_checks())
-  rt = prensor.get_ragged_tensors()
-  with tf.control_dependencies(rt.values()):
-    x = tf.constant(1)
-  return session.make_callable(x, feed_list=[protos])
+    """Returns a callable that broadcasts a field and creates ragged tensors."""
+    protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))
+    expr = (
+        s2t.expression_impl.proto.create_expression_from_proto(protos, proto_descriptor)
+        .broadcast(path_to_broadcast, sibling, "new_child")
+        .project(
+            [
+                path_to_broadcast.get_parent().concat(
+                    s2t.path.Path([sibling, "new_child"])
+                )
+            ]
+        )
+    )
+    [prensor] = s2t.calculate.calculate_prensors(
+        [expr], options=s2t.calculate_options.get_options_with_minimal_checks()
+    )
+    rt = prensor.get_ragged_tensors()
+    with tf.control_dependencies(rt.values()):
+        x = tf.constant(1)
+    return session.make_callable(x, feed_list=[protos])


 def _get_reroot_fn(session, proto_descriptor, path_to_reroot, path_to_project):
-  """Returns a callable that reroots a field, and creates ragged tensors."""
-  protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))
-  expr = s2t.expression_impl.proto.create_expression_from_proto(
-      protos, proto_descriptor).reroot(path_to_reroot).project(path_to_project)
-  [prensor] = s2t.calculate.calculate_prensors(
-      [expr], options=s2t.calculate_options.get_options_with_minimal_checks())
-  rt = prensor.get_ragged_tensors()
-  with tf.control_dependencies(rt.values()):
-    x = tf.constant(1)
-  return session.make_callable(x, feed_list=[protos])
+    """Returns a callable that reroots a field and creates ragged tensors."""
+    protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))
+    expr = (
+        s2t.expression_impl.proto.create_expression_from_proto(protos, proto_descriptor)
+        .reroot(path_to_reroot)
+        .project(path_to_project)
+    )
+    [prensor] = s2t.calculate.calculate_prensors(
+        [expr], options=s2t.calculate_options.get_options_with_minimal_checks()
+    )
+    rt = prensor.get_ragged_tensors()
+    with tf.control_dependencies(rt.values()):
+        x = tf.constant(1)
+    return session.make_callable(x, feed_list=[protos])


 def _get_prensor_to_dense_tensor_fn(session, proto_descriptor, path):
-  """Returns a callable that projects a field, and creates ragged tensors."""
-  protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))
-  expr = s2t.expression_impl.proto.create_expression_from_proto(
-      protos, proto_descriptor)
-  expr = expr.project(path)
-  [prensor] = s2t.calculate.calculate_prensors(
-      [expr], options=s2t.calculate_options.get_options_with_minimal_checks())
-  ragged_tensors = prensor.get_ragged_tensors()
-  dt = [rt.to_tensor() for rt in ragged_tensors.values()]
-  with tf.control_dependencies(dt):
-    x = tf.constant(1)
-  return session.make_callable(x, feed_list=[protos])
+    """Returns a callable that projects a field and creates dense tensors."""
+    protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))
+    expr = s2t.expression_impl.proto.create_expression_from_proto(
+        protos, proto_descriptor
+    )
+    expr = expr.project(path)
+    [prensor] = s2t.calculate.calculate_prensors(
+        [expr], options=s2t.calculate_options.get_options_with_minimal_checks()
+    )
+    ragged_tensors = prensor.get_ragged_tensors()
+    dt = [rt.to_tensor() for rt in ragged_tensors.values()]
+    with tf.control_dependencies(dt):
+        x = tf.constant(1)
+    return session.make_callable(x, feed_list=[protos])


 def _get_prensor_to_ragged_tensor_fn(session, proto_descriptor, path):
-  """Returns a callable that projects a field, and creates ragged tensors."""
-  protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))
-  expr = s2t.expression_impl.proto.create_expression_from_proto(
-      protos, proto_descriptor)
-  expr = expr.project(path)
-  [prensor] = s2t.calculate.calculate_prensors(
-      [expr], options=s2t.calculate_options.get_options_with_minimal_checks())
-  rt = prensor.get_ragged_tensors()
-  with tf.control_dependencies(rt.values()):
-    x = tf.constant(1)
-  return session.make_callable(x, feed_list=[protos])
+    """Returns a callable that projects a field and creates ragged tensors."""
+    protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))
+    expr = s2t.expression_impl.proto.create_expression_from_proto(
+        protos, proto_descriptor
+    )
+    expr = expr.project(path)
+    [prensor] = s2t.calculate.calculate_prensors(
+        [expr], options=s2t.calculate_options.get_options_with_minimal_checks()
+    )
+    rt = prensor.get_ragged_tensors()
+    with tf.control_dependencies(rt.values()):
+        x = tf.constant(1)
+    return session.make_callable(x, feed_list=[protos])


 def _get_prensor_to_sparse_tensor_fn(session, proto_descriptor, path):
-  """Returns a callable that projects a field, and creates sparse tensors."""
-  protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))
-  expr = s2t.expression_impl.proto.create_expression_from_proto(
-      protos, proto_descriptor)
-  expr = expr.project(path)
-  [prensor] = s2t.calculate.calculate_prensors(
-      [expr], options=s2t.calculate_options.get_options_with_minimal_checks())
-  st = prensor.get_sparse_tensors()
-  with tf.control_dependencies(
-      [sparse_tensor.indices for sparse_tensor in st.values()]):
-    x = tf.constant(1)
-  return session.make_callable(x, feed_list=[protos])
+    """Returns a callable that projects a field and creates sparse tensors."""
+    protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))
+    expr = s2t.expression_impl.proto.create_expression_from_proto(
+        protos, proto_descriptor
+    )
+    expr = expr.project(path)
+    [prensor] = s2t.calculate.calculate_prensors(
+        [expr], options=s2t.calculate_options.get_options_with_minimal_checks()
+    )
+    st = prensor.get_sparse_tensors()
+    with tf.control_dependencies(
+        [sparse_tensor.indices for sparse_tensor in st.values()]
+    ):
+        x = tf.constant(1)
+    return session.make_callable(x, feed_list=[protos])


 def _get_parse_tf_example_fn(session, features):
-  protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))
+    protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))

-  tensor_map = tf.io.parse_example(protos, features)
-  with tf.control_dependencies(tensor_map.values()):
-    x = tf.constant(1)
+    tensor_map = tf.io.parse_example(protos, features)
+    with tf.control_dependencies(tensor_map.values()):
+        x = tf.constant(1)

-  return session.make_callable(x, feed_list=[protos])
+    return session.make_callable(x, feed_list=[protos])


 def _get_parse_tf_example_to_sparse_fn(session, features):
-  protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))
+    protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))

-  tensor_map = tf.io.parse_example(protos, features)
-  with tf.control_dependencies(
-      [sparse_tensor.indices for sparse_tensor in tensor_map.values()]):
-    x = tf.constant(1)
+    tensor_map = tf.io.parse_example(protos, features)
+    with tf.control_dependencies(
+        [sparse_tensor.indices for sparse_tensor in tensor_map.values()]
+    ):
+        x = tf.constant(1)

-  return session.make_callable(x, feed_list=[protos])
+    return session.make_callable(x, feed_list=[protos])


 if __name__ == "__main__":
-  tf.test.main()
+    tf.test.main()
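
All of the helper functions above share one pattern: the session callable fetches a cheap constant that carries a control dependency on the expensive tensors, so the measured work runs on every call without its outputs being copied out of the session. A stripped-down sketch of that pattern, assuming TF1-style graph mode (the stand-in op is illustrative, not part of this change):

    import tensorflow as tf

    tf.compat.v1.disable_eager_execution()

    with tf.compat.v1.Session() as sess:
        inp = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,))
        expensive = tf.strings.length(inp)  # stand-in for the op being benchmarked
        # The fetched constant is trivially cheap, but the control dependency
        # forces `expensive` to run on every call. The harness below must also
        # disable grappler's dependency optimization so this edge is not pruned.
        with tf.control_dependencies([expensive]):
            out = tf.constant(1)
        fn = sess.make_callable(out, feed_list=[inp])
        fn(["abc", "de"])  # evaluates `expensive` even though only `out` is fetched
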
"""Returns a callable that projects a field, and creates sparse tensors.""" + protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,)) + expr = s2t.expression_impl.proto.create_expression_from_proto( + protos, proto_descriptor + ) + expr = expr.project(path) + [prensor] = s2t.calculate.calculate_prensors( + [expr], options=s2t.calculate_options.get_options_with_minimal_checks() + ) + st = prensor.get_sparse_tensors() + with tf.control_dependencies( + [sparse_tensor.indices for sparse_tensor in st.values()] + ): + x = tf.constant(1) + return session.make_callable(x, feed_list=[protos]) def _get_parse_tf_example_fn(session, features): - protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,)) + protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,)) - tensor_map = tf.io.parse_example(protos, features) - with tf.control_dependencies(tensor_map.values()): - x = tf.constant(1) + tensor_map = tf.io.parse_example(protos, features) + with tf.control_dependencies(tensor_map.values()): + x = tf.constant(1) - return session.make_callable(x, feed_list=[protos]) + return session.make_callable(x, feed_list=[protos]) def _get_parse_tf_example_to_sparse_fn(session, features): - protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,)) + protos = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,)) - tensor_map = tf.io.parse_example(protos, features) - with tf.control_dependencies( - [sparse_tensor.indices for sparse_tensor in tensor_map.values()]): - x = tf.constant(1) + tensor_map = tf.io.parse_example(protos, features) + with tf.control_dependencies( + [sparse_tensor.indices for sparse_tensor in tensor_map.values()] + ): + x = tf.constant(1) - return session.make_callable(x, feed_list=[protos]) + return session.make_callable(x, feed_list=[protos]) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/struct2tensor/benchmarks/struct2tensor_benchmark_util.py b/struct2tensor/benchmarks/struct2tensor_benchmark_util.py index 466262d..d6019b7 100644 --- a/struct2tensor/benchmarks/struct2tensor_benchmark_util.py +++ b/struct2tensor/benchmarks/struct2tensor_benchmark_util.py @@ -24,14 +24,13 @@ from absl import flags from absl.testing import parameterized from tensorflow.core.protobuf import ( - rewriter_config_pb2, # pylint: disable=g-direct-tensorflow-import + rewriter_config_pb2, # pylint: disable=g-direct-tensorflow-import ) FLAGS = flags.FLAGS flags.DEFINE_bool( - "test_mode", False, - "if True, run all benchmarks with two iterations." + "test_mode", False, "if True, run all benchmarks with two iterations." ) _BASE_DIR = os.path.join(os.path.dirname(__file__), "testdata") @@ -39,242 +38,245 @@ class Struct2tensorBenchmarksBase(parameterized.TestCase): - """Base Class for Struct2tensor benchmarks. + """Base Class for Struct2tensor benchmarks. - These are tensorflow based benchmarks, and ensures that tensors are always - evaluated. - - The derived class should call - self.run_benchmarks(fn_name, get_benchmark_fn, fn_args, data_key). - """ - - def _discard_runs(self, benchmark_fn, inputs): - benchmark_fn(*inputs) - - def run_benchmarks(self, fn_name, get_benchmark_fn, fn_args, data_key): - """This benchmarks the function specified by `get_benchmark_fn`. - - Args: - fn_name: A string of the name of the function to benchmark. - get_benchmark_fn: A function that returns a tf.session callable. - fn_args: A list of arguments for get_benchmark_fn. - data_key: A string that determines what data to use for benchmarks. 
See - child class for possible defined data_keys. + These are tensorflow based benchmarks, and ensures that tensors are always + evaluated. + The derived class should call + self.run_benchmarks(fn_name, get_benchmark_fn, fn_args, data_key). """ - print(f"BEGIN {fn_name}:\tNum Iterations\tTotal (wall) Time (s)\t" - "Wall Time avg(ms)\tWall Time std\tUser CPU avg (ms)\t" - "User CPU std\tSystem CPU avg (ms)\tSystem CPU std") - - iterations = 1000 - # This is the number of iterations in a sample. We will benchmark the - # total compute time per sample. And find std across all samples. - sample_size = 100 - if FLAGS.test_mode: - print("WARNING: --test_mode is True. Setting iterations to 2 with sample " - "size of 1.") - iterations = 2 # 2 iterations so we can calculate stdev. - sample_size = 1 - - with ( - tf.Graph().as_default(), - # Disable control dependency optimization. Otherwise we cannot use the - # trick to force the parsing to happen without using its output. e.g.: - # with tf.control_dependencies([parsed_tensors]): - # return tf.constant(1) - tf.compat.v1.Session( - config=tf.compat.v1.ConfigProto( - graph_options=tf.compat.v1.GraphOptions( - rewrite_options=rewriter_config_pb2.RewriterConfig( - dependency_optimization=rewriter_config_pb2.RewriterConfig - .OFF)))) - as sess - ): - benchmark_fn = get_benchmark_fn(sess, *fn_args) - - for input_data in self._data[data_key]: - self._discard_runs(benchmark_fn, input_data) - - # Collect timings and run the benchmark function - name = f"{fn_name}_{len(input_data[0])}" - wall_times, user_cpu_times, system_cpu_times = [], [], [] - for _ in range(int(iterations / sample_size)): - start_cpu = psutil.cpu_times() - start_time = timeit.default_timer() - for _ in range(sample_size): - _ = benchmark_fn(*input_data) - end_cpu = psutil.cpu_times() - duration = timeit.default_timer() - start_time - cpu_user_duration = end_cpu.user - start_cpu.user - cpu_system_duration = end_cpu.system - start_cpu.system - - wall_times.append(duration * 1000) - user_cpu_times.append(cpu_user_duration * 1000) - system_cpu_times.append(cpu_system_duration * 1000) - - total_duration = sum(wall_times) - cpu_user_duration = sum(user_cpu_times) - cpu_system_duration = sum(system_cpu_times) - # The columns are |function name|Num Iterations|Total (wall) Time (s)| - # Wall Time avg(ms)|Wall Time std|User CPU avg (ms)|User CPU std| - # System CPU avg (ms)|System CPU std| - print( - f"{name}: \t{iterations}\t{total_duration / 1000}\t" - f"{total_duration / iterations}\t{statistics.stdev(wall_times)}\t" - f"{cpu_user_duration / iterations}\t" - f"{statistics.stdev(user_cpu_times)}\t" - f"{cpu_system_duration / iterations}\t" - f"{statistics.stdev(system_cpu_times)}" - ) + + def _discard_runs(self, benchmark_fn, inputs): + benchmark_fn(*inputs) + + def run_benchmarks(self, fn_name, get_benchmark_fn, fn_args, data_key): + """This benchmarks the function specified by `get_benchmark_fn`. + + Args: + fn_name: A string of the name of the function to benchmark. + get_benchmark_fn: A function that returns a tf.session callable. + fn_args: A list of arguments for get_benchmark_fn. + data_key: A string that determines what data to use for benchmarks. See + child class for possible defined data_keys. + + """ + print( + f"BEGIN {fn_name}:\tNum Iterations\tTotal (wall) Time (s)\t" + "Wall Time avg(ms)\tWall Time std\tUser CPU avg (ms)\t" + "User CPU std\tSystem CPU avg (ms)\tSystem CPU std" + ) + + iterations = 1000 + # This is the number of iterations in a sample. 
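
In isolation, the measurement loop above reduces to the following sketch; `time_samples` and its arguments are illustrative names, not part of this patch. Note that psutil.cpu_times() is machine-wide, so the CPU columns include activity from other processes:

    import statistics
    import timeit

    import psutil

    def time_samples(fn, samples=10, calls_per_sample=100):
        wall_ms, user_ms = [], []
        for _ in range(samples):
            start_cpu = psutil.cpu_times()
            start = timeit.default_timer()
            for _ in range(calls_per_sample):
                fn()
            end_cpu = psutil.cpu_times()
            wall_ms.append((timeit.default_timer() - start) * 1000)
            user_ms.append((end_cpu.user - start_cpu.user) * 1000)
        # Mean and stdev per sample, as in the report printed above.
        return (
            statistics.mean(wall_ms),
            statistics.stdev(wall_ms),
            statistics.mean(user_ms),
            statistics.stdev(user_ms),
        )
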


 class ProtoDataBenchmarks(Struct2tensorBenchmarksBase):
-  """Base class for benchmarks that take proto data as input."""
-
-  @classmethod
-  def setUpClass(cls):
-    super().setUpClass()
-    print(f"cpuinfo: {_CPU_INFO}")
-
-    shuffle_size = 2048
-
-    # Batch sizes are powers of two. i.e. 1, 2, 4, 8, .. up to 2048.
-    batch_sizes = [2**i for i in range(11)]
-
-    if FLAGS.test_mode:
-      print("WARNING: --test_mode is True. Setting batch_size to 1.")
-      shuffle_size = 1
-      batch_sizes = [1]
-
-    # load tensor of serialized DeepProtos into memory.
-    ds = tf.data.TFRecordDataset(
-        os.path.join(_BASE_DIR, "deep_all_types_4096_positive.tfrecord.gz"),
-        "GZIP").shuffle(shuffle_size)
-    cls._deep_protos = [
-        [list(ds.take(batch_size).as_numpy_iterator())]
-        for batch_size in batch_sizes
-    ]
-
-    # load tensor of serialized FlatProtos into memory.
-    ds = tf.data.TFRecordDataset(
-        os.path.join(_BASE_DIR, "flat_all_types_4096_positive.tfrecord.gz"),
-        "GZIP").shuffle(shuffle_size)
-    cls._flat_protos = [
-        [list(ds.take(batch_size).as_numpy_iterator())]
-        for batch_size in batch_sizes
-    ]
-
-    # load tensor of serialized FlatProtos into memory. These protos each have
-    # 1 feature, each with 100 values.
-    ds = tf.data.TFRecordDataset(
-        os.path.join(_BASE_DIR,
-                     "flat_all_types_100_int_values_4096.tfrecord.gz"),
-        "GZIP").shuffle(shuffle_size)
-    cls._flat_protos_1_feature_100_values = [
-        [list(ds.take(batch_size).as_numpy_iterator())]
-        for batch_size in batch_sizes
-    ]
-
-    # load tensor of serialized FlatProtos100 into memory. These protos each
-    # have 100 int features, where each feature has 1 value.
-    ds = tf.data.TFRecordDataset(
-        os.path.join(_BASE_DIR, "flat_100_int_features_4096.tfrecord.gz"),
-        "GZIP").shuffle(shuffle_size)
-    cls._flat_protos_100_features = [
-        [list(ds.take(batch_size).as_numpy_iterator())]
-        for batch_size in batch_sizes
-    ]
-
-    # load tensor of serialized FlatProtos100 into memory. These protos each
-    # have 100 int features, where each feature has 100 values.
-    ds = tf.data.TFRecordDataset(
-        os.path.join(_BASE_DIR,
-                     "flat_100_int_features_100_values_4096.tfrecord.gz"),
-        "GZIP").shuffle(shuffle_size)
-    cls._flat_protos_100_features_100_values = [
-        [list(ds.take(batch_size).as_numpy_iterator())]
-        for batch_size in batch_sizes
-    ]
-
-    # load tf example into memory. Each example has 1 feature of list length 1.
-    ds = tf.data.TFRecordDataset(
-        os.path.join(_BASE_DIR, "tf_example_all_types_4096.tfrecord.gz"),
-        "GZIP").shuffle(shuffle_size)
-    cls._tf_examples = [
-        [list(ds.take(batch_size).as_numpy_iterator())]
-        for batch_size in batch_sizes
-    ]
-
-    # load tf example into memory. Each example has 1 feature of 100 values.
-    ds = tf.data.TFRecordDataset(
-        os.path.join(_BASE_DIR,
-                     "tf_example_1_int_feature_100_values_4096.tfrecord.gz"),
-        "GZIP").shuffle(shuffle_size)
-    cls._tf_examples_1_feature_100_values = [
-        [list(ds.take(batch_size).as_numpy_iterator())]
-        for batch_size in batch_sizes
-    ]
-
-    # load tf example into memory. Each example has 100 features, where each
-    # feature has list length of 1.
-    ds = tf.data.TFRecordDataset(
-        os.path.join(_BASE_DIR, "tf_example_100_int_features_4096.tfrecord.gz"),
-        "GZIP").shuffle(shuffle_size)
-    cls._tf_examples_100_features = [
-        [list(ds.take(batch_size).as_numpy_iterator())]
-        for batch_size in batch_sizes
-    ]
-
-    # load tf example into memory. Each example has 100 features, where each
-    # feature has list length of 100.
-    ds = tf.data.TFRecordDataset(
-        os.path.join(_BASE_DIR,
-                     "tf_example_100_int_features_100_values_4096.tfrecord.gz"),
-        "GZIP").shuffle(shuffle_size)
-    cls._tf_examples_100_features_100_values = [
-        [list(ds.take(batch_size).as_numpy_iterator())]
-        for batch_size in batch_sizes
-    ]
-
-    # pylint: disable=protected-access
-    cls._data = {
-        "deep_protos":
-            cls._deep_protos,
-        "flat_protos":
-            cls._flat_protos,
-        "flat_protos_1_feature_100_values":
-            cls._flat_protos_1_feature_100_values,
-        "flat_protos_100_features":
-            cls._flat_protos_100_features,
-        "flat_protos_100_features_100_values":
-            cls._flat_protos_100_features_100_values,
-        "tf_examples":
-            cls._tf_examples,
-        "tf_examples_1_feature_100_values":
-            cls._tf_examples_1_feature_100_values,
-        "tf_examples_100_features":
-            cls._tf_examples_100_features,
-        "tf_examples_100_features_100_values":
-            cls._tf_examples_100_features_100_values,
    }
-    # pylint: enable=protected-access
+    """Base class for benchmarks that take proto data as input."""
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        print(f"cpuinfo: {_CPU_INFO}")
+
+        shuffle_size = 2048
+
+        # Batch sizes are powers of two, i.e. 1, 2, 4, 8, ... up to 2048.
+        batch_sizes = [2**i for i in range(11)]
+
+        if FLAGS.test_mode:
+            print("WARNING: --test_mode is True. Setting batch_size to 1.")
+            shuffle_size = 1
+            batch_sizes = [1]
+
+        # Load serialized DeepProtos into memory.
+        ds = tf.data.TFRecordDataset(
+            os.path.join(_BASE_DIR, "deep_all_types_4096_positive.tfrecord.gz"), "GZIP"
+        ).shuffle(shuffle_size)
+        cls._deep_protos = [
+            [list(ds.take(batch_size).as_numpy_iterator())]
+            for batch_size in batch_sizes
+        ]
+
+        # Load serialized FlatProtos into memory.
+        ds = tf.data.TFRecordDataset(
+            os.path.join(_BASE_DIR, "flat_all_types_4096_positive.tfrecord.gz"), "GZIP"
+        ).shuffle(shuffle_size)
+        cls._flat_protos = [
+            [list(ds.take(batch_size).as_numpy_iterator())]
+            for batch_size in batch_sizes
+        ]
+
+        # Load serialized FlatProtos into memory. These protos each have
+        # 1 feature, each with 100 values.
+        ds = tf.data.TFRecordDataset(
+            os.path.join(_BASE_DIR, "flat_all_types_100_int_values_4096.tfrecord.gz"),
+            "GZIP",
+        ).shuffle(shuffle_size)
+        cls._flat_protos_1_feature_100_values = [
+            [list(ds.take(batch_size).as_numpy_iterator())]
+            for batch_size in batch_sizes
+        ]
+
+        # Load serialized FlatProtos100 into memory. These protos each
+        # have 100 int features, where each feature has 1 value.
+        ds = tf.data.TFRecordDataset(
+            os.path.join(_BASE_DIR, "flat_100_int_features_4096.tfrecord.gz"), "GZIP"
+        ).shuffle(shuffle_size)
+        cls._flat_protos_100_features = [
+            [list(ds.take(batch_size).as_numpy_iterator())]
+            for batch_size in batch_sizes
+        ]
+
+        # Load serialized FlatProtos100 into memory. These protos each
+        # have 100 int features, where each feature has 100 values.
+        ds = tf.data.TFRecordDataset(
+            os.path.join(
+                _BASE_DIR, "flat_100_int_features_100_values_4096.tfrecord.gz"
+            ),
+            "GZIP",
+        ).shuffle(shuffle_size)
+        cls._flat_protos_100_features_100_values = [
+            [list(ds.take(batch_size).as_numpy_iterator())]
+            for batch_size in batch_sizes
+        ]
+
+        # Load tf.Examples into memory. Each example has 1 feature of list length 1.
+        ds = tf.data.TFRecordDataset(
+            os.path.join(_BASE_DIR, "tf_example_all_types_4096.tfrecord.gz"), "GZIP"
+        ).shuffle(shuffle_size)
+        cls._tf_examples = [
+            [list(ds.take(batch_size).as_numpy_iterator())]
+            for batch_size in batch_sizes
+        ]
+
+        # Load tf.Examples into memory. Each example has 1 feature of 100 values.
+        ds = tf.data.TFRecordDataset(
+            os.path.join(
+                _BASE_DIR, "tf_example_1_int_feature_100_values_4096.tfrecord.gz"
+            ),
+            "GZIP",
+        ).shuffle(shuffle_size)
+        cls._tf_examples_1_feature_100_values = [
+            [list(ds.take(batch_size).as_numpy_iterator())]
+            for batch_size in batch_sizes
+        ]
+
+        # Load tf.Examples into memory. Each example has 100 features, where each
+        # feature has list length of 1.
+        ds = tf.data.TFRecordDataset(
+            os.path.join(_BASE_DIR, "tf_example_100_int_features_4096.tfrecord.gz"),
+            "GZIP",
+        ).shuffle(shuffle_size)
+        cls._tf_examples_100_features = [
+            [list(ds.take(batch_size).as_numpy_iterator())]
+            for batch_size in batch_sizes
+        ]
+
+        # Load tf.Examples into memory. Each example has 100 features, where each
+        # feature has list length of 100.
+        ds = tf.data.TFRecordDataset(
+            os.path.join(
+                _BASE_DIR, "tf_example_100_int_features_100_values_4096.tfrecord.gz"
+            ),
+            "GZIP",
+        ).shuffle(shuffle_size)
+        cls._tf_examples_100_features_100_values = [
+            [list(ds.take(batch_size).as_numpy_iterator())]
+            for batch_size in batch_sizes
+        ]
+
+        # pylint: disable=protected-access
+        cls._data = {
+            "deep_protos": cls._deep_protos,
+            "flat_protos": cls._flat_protos,
+            "flat_protos_1_feature_100_values": cls._flat_protos_1_feature_100_values,
+            "flat_protos_100_features": cls._flat_protos_100_features,
+            "flat_protos_100_features_100_values": cls._flat_protos_100_features_100_values,
+            "tf_examples": cls._tf_examples,
+            "tf_examples_1_feature_100_values": cls._tf_examples_1_feature_100_values,
+            "tf_examples_100_features": cls._tf_examples_100_features,
+            "tf_examples_100_features_100_values": cls._tf_examples_100_features_100_values,
+        }
+        # pylint: enable=protected-access
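
Each block above follows the same load-into-memory pattern, which can be tried standalone; the file name below is a placeholder, not one of the testdata files:

    import tensorflow as tf

    # "example.tfrecord.gz" is a hypothetical path; the second positional
    # argument of TFRecordDataset is the compression type.
    ds = tf.data.TFRecordDataset("example.tfrecord.gz", "GZIP").shuffle(64)
    # One in-memory list of serialized protos per batch size.
    batches = [list(ds.take(2**i).as_numpy_iterator()) for i in range(3)]
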


 class OpsBenchmarks(Struct2tensorBenchmarksBase):
-  """Base class for benchmarks that take tensor data as input."""
+    """Base class for benchmarks that take tensor data as input."""

-  @classmethod
-  def setUpClass(cls):
-    super().setUpClass()
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()

-    print(f"cpuinfo: {_CPU_INFO}")
-    a = list(range(1000))
-    b = list(range(1000))
+        print(f"cpuinfo: {_CPU_INFO}")
+        a = list(range(1000))
+        b = list(range(1000))

-    rand_a = list(range(1000))
-    rand_b = list(range(1000))
-    random.shuffle(rand_a)
-    random.shuffle(rand_b)
+        rand_a = list(range(1000))
+        rand_b = list(range(1000))
+        random.shuffle(rand_a)
+        random.shuffle(rand_b)

-    cls._data = {"monotonic_increasing": [[a, b]],
-                 "random": [[rand_a, rand_b]]}
+        cls._data = {"monotonic_increasing": [[a, b]], "random": [[rand_a, rand_b]]}
diff --git a/struct2tensor/calculate.py b/struct2tensor/calculate.py
index 0b187ea..92111f5 100644
--- a/struct2tensor/calculate.py
+++ b/struct2tensor/calculate.py
@@ -57,421 +57,437 @@
 def calculate_values_with_graph(
     expressions: List[expression.Expression],
     options: Optional[calculate_options.Options] = None,
-    feed_dict: Optional[Dict[expression.Expression, prensor.Prensor]] = None
+    feed_dict: Optional[Dict[expression.Expression, prensor.Prensor]] = None,
 ) -> Tuple[List[prensor.NodeTensor], "ExpressionGraph"]:
-  """Calculates the values of the expressions, and the graph used.
+    """Calculates the values of the expressions, and the graph used.

-  Note that this does not return prensors, but instead a list of NodeTensors.
+    Note that this does not return prensors, but instead a list of NodeTensors.

-  Args:
-    expressions: a list of expressions to calculate.
-    options: options for calculations, passed to calculate(...).
-    feed_dict: a dictionary, mapping expression to prensor that will be used
-      as the initial expression in the expression graph.
+    Args:
+      expressions: a list of expressions to calculate.
+      options: options for calculations, passed to calculate(...).
+      feed_dict: a dictionary mapping expressions to prensors, used as the
+        initial values in the expression graph.

-  Returns:
-    the list of values and the graph used to calculate them.
-  """
-  if options is None:
-    options = calculate_options.get_default_options()
-  expression_graph = _create_graph(expressions, options, feed_dict=feed_dict)
-  return ([expression_graph.get_value_or_die(x) for x in expressions],
-          expression_graph)
+    Returns:
+      the list of values and the graph used to calculate them.
+    """
+    if options is None:
+        options = calculate_options.get_default_options()
+    expression_graph = _create_graph(expressions, options, feed_dict=feed_dict)
+    return (
+        [expression_graph.get_value_or_die(x) for x in expressions],
+        expression_graph,
+    )


 def calculate_values(
     expressions: List[expression.Expression],
     options: Optional[calculate_options.Options] = None,
-    feed_dict: Optional[Dict[expression.Expression, prensor.Prensor]] = None
+    feed_dict: Optional[Dict[expression.Expression, prensor.Prensor]] = None,
 ) -> List[prensor.NodeTensor]:
-  """Calculates the values of the expressions.
+    """Calculates the values of the expressions.

-  Note that this does not return prensors, but instead a list of NodeTensors.
+    Note that this does not return prensors, but instead a list of NodeTensors.

-  Args:
-    expressions: A list of expressions to calculate.
-    options: options for calculate(...) operations.
-    feed_dict: a dictionary, mapping expression to prensor that will be used
-      as the initial expression in the expression graph.
+    Args:
+      expressions: A list of expressions to calculate.
+      options: options for calculate(...) operations.
+      feed_dict: a dictionary mapping expressions to prensors, used as the
+        initial values in the expression graph.

-  Returns:
-    A list of NodeTensor values.
-  """
-  return calculate_values_with_graph(
-      expressions, options=options, feed_dict=feed_dict)[0]
+    Returns:
+      A list of NodeTensor values.
+    """
+    return calculate_values_with_graph(
+        expressions, options=options, feed_dict=feed_dict
+    )[0]


 def calculate_prensors_with_graph(
     expressions: Sequence[expression.Expression],
     options: Optional[calculate_options.Options] = None,
-    feed_dict: Optional[Dict[expression.Expression, prensor.Prensor]] = None
+    feed_dict: Optional[Dict[expression.Expression, prensor.Prensor]] = None,
 ) -> Tuple[Sequence[prensor.Prensor], "ExpressionGraph"]:
-  """Gets the prensor value of the expressions and the graph used.
-
-  This method is useful for getting information like the protobuf fields parsed
-  to create an expression.
-
-  Args:
-    expressions: expressions to calculate prensors for.
-    options: options for calculate(...) methods.
-    feed_dict: a dictionary, mapping expression to prensor that will be used
-      as the initial expression in the expression graph.
-
-  Returns:
-    a list of prensors, and the graph used to calculate them.
-  """
-  subtrees = [x.get_known_descendants() for x in expressions]
-  all_expressions = []
-  for tree in subtrees:
-    all_expressions.extend(tree.values())
-  values, graph = calculate_values_with_graph(
-      all_expressions, options=options, feed_dict=feed_dict)
-  expr_value_pairs = zip(all_expressions, values)
-  value_map = {}
-  for expr, value in expr_value_pairs:
-    if id(expr) not in value_map:
-      value_map[id(expr)] = value
-  return ([_get_prensor(subtree, value_map) for subtree in subtrees], graph)
+    """Gets the prensor value of the expressions and the graph used.
+
+    This method is useful for getting information like the protobuf fields parsed
+    to create an expression.
+
+    Args:
+      expressions: expressions to calculate prensors for.
+      options: options for calculate(...) methods.
+      feed_dict: a dictionary mapping expressions to prensors, used as the
+        initial values in the expression graph.
+
+    Returns:
+      a list of prensors, and the graph used to calculate them.
+    """
+    subtrees = [x.get_known_descendants() for x in expressions]
+    all_expressions = []
+    for tree in subtrees:
+        all_expressions.extend(tree.values())
+    values, graph = calculate_values_with_graph(
+        all_expressions, options=options, feed_dict=feed_dict
+    )
+    expr_value_pairs = zip(all_expressions, values)
+    value_map = {}
+    for expr, value in expr_value_pairs:
+        if id(expr) not in value_map:
+            value_map[id(expr)] = value
+    return ([_get_prensor(subtree, value_map) for subtree in subtrees], graph)


 def calculate_prensors(
     expressions: Sequence[expression.Expression],
     options: Optional[calculate_options.Options] = None,
-    feed_dict: Optional[Dict[expression.Expression, prensor.Prensor]] = None
+    feed_dict: Optional[Dict[expression.Expression, prensor.Prensor]] = None,
 ) -> Sequence[prensor.Prensor]:
-  """Gets the prensor value of the expressions.
+    """Gets the prensor value of the expressions.

-  Args:
-    expressions: expressions to calculate prensors for.
-    options: options for calculate(...).
-    feed_dict: a dictionary, mapping expression to prensor that will be used
-      as the initial expression in the expression graph.
+    Args:
+      expressions: expressions to calculate prensors for.
+      options: options for calculate(...).
+      feed_dict: a dictionary mapping expressions to prensors, used as the
+        initial values in the expression graph.

-  Returns:
-    a list of prensors.
-  """
-  return calculate_prensors_with_graph(
-      expressions, options=options, feed_dict=feed_dict)[0]
+    Returns:
+      a list of prensors.
+    """
+    return calculate_prensors_with_graph(
+        expressions, options=options, feed_dict=feed_dict
+    )[0]
- """ - result = expr - while result.calculation_is_identity(): - result = result.get_source_expressions()[0] - return result - - -def _fancy_type_str(is_repeated: bool, dtype: Optional[tf.DType]): - return "{} {}".format("repeated" if is_repeated else "optional", dtype) - - -def _node_type_str(node_tensor: prensor.NodeTensor): - if isinstance(node_tensor, prensor.LeafNodeTensor): - return _fancy_type_str(node_tensor.is_repeated, node_tensor.values.dtype) - return _fancy_type_str(node_tensor.is_repeated, None) - - -class _ExpressionNode: - """A node representing an expression in the ExpressionGraph.""" - - def __init__(self, expr: expression.Expression): - """Construct a node in the graph. - - Args: - expr: must be the result of _get_earliest_equal_calculation(...) - """ - self.expression = expr - self.sources = [ - _get_earliest_equal_calculation(x) - for x in expr.get_source_expressions() - ] - self.destinations = [] # type: List[_ExpressionNode] - self.value = None - - def __eq__(self, node: "_ExpressionNode") -> bool: - """Test if this node is equal to the other. - - Requires that all sources are already canonical. +def _get_prensor( + subtree: Mapping[path.Path, expression.Expression], + values: Mapping[IDNodeTensor, prensor.NodeTensor], +) -> prensor.Prensor: + """Gets the prensor tree value of the subtree. Args: - node: Another node to compare to. + subtree: a mapping paths to expressions returned by + expression.get_known_descendants. + values: a mapping from expression ids to node values. Returns: - True if the nodes are guaranteed to have equal value. + a prensor tree. """ - if not self.expression.calculation_equal(node.expression): - return False - # This assumes the sources are canonical, so we check for identity equality - # as opposed to semantic equality. 
- return all([a is b for a, b in zip(self.sources, node.sources)]) - - def __str__(self) -> str: - return (f"expression: {str(self.expression)} sources: {str(self.sources)} destinations: " - f"{str(self.destinations)} value: {str(self.value)}") - - def _create_value_error(self) -> ValueError: - """Creates a ValueError, assuming there should be one for this node.""" - return ValueError("Expression {} returned the wrong type:" - " expected: {}" - " actual: {}.".format( - self.expression, - _fancy_type_str(self.expression.is_repeated, - self.expression.type), - _node_type_str(self.value))) - - def calculate(self, - source_values: Sequence[prensor.NodeTensor], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor]) -> None: - """Calculate the value of the node, and store it in self.value.""" - self.value = self.expression.calculate( - source_values, [x.expression for x in self.destinations], - options, - side_info=side_info) - if self.value.is_repeated != self.expression.is_repeated: - raise self._create_value_error() - expected_type = self.expression.type - actual_value = self.value - if expected_type is None: - if not isinstance(actual_value, (prensor.RootNodeTensor, prensor.ChildNodeTensor)): - raise self._create_value_error() - elif isinstance(actual_value, prensor.LeafNodeTensor): - if expected_type != actual_value.values.dtype: - raise self._create_value_error() - else: - raise self._create_value_error() - - def __hash__(self) -> int: - """This assumes all sources are canonical.""" - return hash(tuple([id(x) for x in self.sources])) + return prensor.create_prensor_from_descendant_nodes( + {k: values[id(v)] for k, v in subtree.items()} + ) -class ExpressionGraph: - """A graph representing the computation of a list of expressions.""" - - def __init__(self): - self._node = {} # type: Dict[IDExpression, _ExpressionNode] - # An ordered list of nodes. - self._ordered_node_list = [] # type: List[_ExpressionNode] - - @property - def ordered_node_list(self): - """Do not add or delete elements.""" - return self._ordered_node_list +def _get_earliest_equal_calculation(expr: expression.Expression): + """Finds an expression with an equal value. - def _get_node(self, expr: expression.Expression) -> Optional[_ExpressionNode]: - """Gets a node corresponding to the expression. + Recursively traverses sources while expressions are the identity. Args: - expr: the expression to look up. + expr: the expression to find an equal value. Returns: - A node for an expression that returns the same values as the - expression. + An expression with an equal value to the input. """ - earliest = _get_earliest_equal_calculation(expr) - return self._node.get(id(earliest)) - - def get_value(self, - expr: expression.Expression) -> Optional[prensor.NodeTensor]: - node = self._get_node(expr) - return None if node is None else node.value - - def get_value_or_die(self, expr: expression.Expression) -> prensor.NodeTensor: - result = self.get_value(expr) - if result is None: - raise ValueError("Could not find expression's value") + result = expr + while result.calculation_is_identity(): + result = result.get_source_expressions()[0] return result - def _find_destinations(self, nodes: Sequence[_ExpressionNode]) -> None: - """Initialize destinations of nodes. - - For all nodes u,v in nodes: - if u is a source of v, make v a destination of u. - Args: - nodes: a list containing every node in the graph exactly once. 
- """ - for node in nodes: - for source in node.sources: - source_node = self._get_node(source) - if source_node is None: - raise ValueError("Could not find source of node") - source_node.destinations.append(node) - - def calculate_values( - self, - options: calculate_options.Options, - feed_dict: Optional[Dict[expression.Expression, prensor.Prensor]] = None - ) -> None: - for node in self.ordered_node_list: - source_values = [self._node[id(x)].value for x in node.sources] - side_info = feed_dict[node.expression] if feed_dict and ( - node.expression in feed_dict) else None - node.calculate(source_values, options, side_info=side_info) - - def get_expressions_needed(self) -> Sequence[expression.Expression]: - return [x.expression for x in self.ordered_node_list] - - def __str__(self): - return str([str(x) for x in self.ordered_node_list]) +def _fancy_type_str(is_repeated: bool, dtype: Optional[tf.DType]): + return "{} {}".format("repeated" if is_repeated else "optional", dtype) -class OriginalExpressionGraph(ExpressionGraph): - """A graph representing the computation of a list of expressions. - - This can directly consume a list of expressions, and create a graph. - In terms of calculating the values, the primary purpose here is to enumerate - all expressions. - """ - - def __init__(self, expressions: Sequence[expression.Expression]): - super().__init__() - self._add_expressions(expressions) - original_nodes = list(self._node.values()) - self._find_destinations(original_nodes) - self._order_nodes() - - def _add_expressions(self, - expressions: Sequence[expression.Expression]) -> None: - """Add expressions to the graph.""" - if self.ordered_node_list: - raise ValueError("Only call once during construction") - to_add = list(expressions) - while to_add: - expr = to_add.pop() - if self._get_node(expr) is None: - earliest_expr = _get_earliest_equal_calculation(expr) - node = _ExpressionNode(earliest_expr) - self._node[id(earliest_expr)] = node - to_add.extend(node.sources) - - def _order_nodes(self) -> None: - """Topologically sorts the nodes and puts them in ordered_node_list.""" - nodes_to_process = [] # type: List[_ExpressionNode] - if self.ordered_node_list: - raise ValueError("Only call once during construction") - expr_id_to_count = { - expr_id: len(node.sources) for expr_id, node in self._node.items() - } - for k, v in sorted(expr_id_to_count.items()): - if v == 0: - nodes_to_process.append(self._node[k]) - while nodes_to_process: - node = nodes_to_process.pop() - self._ordered_node_list.append(node) - for dest in node.destinations: - count = expr_id_to_count[id(dest.expression)] - 1 - expr_id_to_count[id(dest.expression)] = count - if count == 0: - nodes_to_process.append(dest) +def _node_type_str(node_tensor: prensor.NodeTensor): + if isinstance(node_tensor, prensor.LeafNodeTensor): + return _fancy_type_str(node_tensor.is_repeated, node_tensor.values.dtype) + return _fancy_type_str(node_tensor.is_repeated, None) -class CanonicalExpressionGraph(ExpressionGraph): - """A graph representing the computation of a list of expressions. +class _ExpressionNode: + """A node representing an expression in the ExpressionGraph.""" + + def __init__(self, expr: expression.Expression): + """Construct a node in the graph. + + Args: + expr: must be the result of _get_earliest_equal_calculation(...) 
+ """ + self.expression = expr + self.sources = [ + _get_earliest_equal_calculation(x) for x in expr.get_source_expressions() + ] + self.destinations = [] # type: List[_ExpressionNode] + self.value = None + + def __eq__(self, node: "_ExpressionNode") -> bool: + """Test if this node is equal to the other. + + Requires that all sources are already canonical. + + Args: + node: Another node to compare to. + + Returns: + True if the nodes are guaranteed to have equal value. + """ + if not self.expression.calculation_equal(node.expression): + return False + # This assumes the sources are canonical, so we check for identity equality + # as opposed to semantic equality. + return all([a is b for a, b in zip(self.sources, node.sources)]) + + def __str__(self) -> str: + return ( + f"expression: {str(self.expression)} sources: {str(self.sources)} destinations: " + f"{str(self.destinations)} value: {str(self.value)}" + ) + + def _create_value_error(self) -> ValueError: + """Creates a ValueError, assuming there should be one for this node.""" + return ValueError( + "Expression {} returned the wrong type: expected: {} actual: {}.".format( + self.expression, + _fancy_type_str(self.expression.is_repeated, self.expression.type), + _node_type_str(self.value), + ) + ) + + def calculate( + self, + source_values: Sequence[prensor.NodeTensor], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor], + ) -> None: + """Calculate the value of the node, and store it in self.value.""" + self.value = self.expression.calculate( + source_values, + [x.expression for x in self.destinations], + options, + side_info=side_info, + ) + if self.value.is_repeated != self.expression.is_repeated: + raise self._create_value_error() + expected_type = self.expression.type + actual_value = self.value + if expected_type is None: + if not isinstance( + actual_value, (prensor.RootNodeTensor, prensor.ChildNodeTensor) + ): + raise self._create_value_error() + elif isinstance(actual_value, prensor.LeafNodeTensor): + if expected_type != actual_value.values.dtype: + raise self._create_value_error() + else: + raise self._create_value_error() + + def __hash__(self) -> int: + """This assumes all sources are canonical.""" + return hash(tuple([id(x) for x in self.sources])) - In the construction of the graph, if two expressions have inputs which can be - proven to have equal outputs, and the two expressions have equal calculations, - then they are represented by a single "canonical" node. - This canonicalization is done sequentially through the ordered_node_list of - the input graph (to avoid Python recursion limits). - """ +class ExpressionGraph: + """A graph representing the computation of a list of expressions.""" + + def __init__(self): + self._node = {} # type: Dict[IDExpression, _ExpressionNode] + # An ordered list of nodes. + self._ordered_node_list = [] # type: List[_ExpressionNode] + + @property + def ordered_node_list(self): + """Do not add or delete elements.""" + return self._ordered_node_list + + def _get_node(self, expr: expression.Expression) -> Optional[_ExpressionNode]: + """Gets a node corresponding to the expression. + + Args: + expr: the expression to look up. + + Returns: + A node for an expression that returns the same values as the + expression. 
+ """ + earliest = _get_earliest_equal_calculation(expr) + return self._node.get(id(earliest)) + + def get_value(self, expr: expression.Expression) -> Optional[prensor.NodeTensor]: + node = self._get_node(expr) + return None if node is None else node.value + + def get_value_or_die(self, expr: expression.Expression) -> prensor.NodeTensor: + result = self.get_value(expr) + if result is None: + raise ValueError("Could not find expression's value") + return result + + def _find_destinations(self, nodes: Sequence[_ExpressionNode]) -> None: + """Initialize destinations of nodes. + + For all nodes u,v in nodes: + if u is a source of v, make v a destination of u. + + Args: + nodes: a list containing every node in the graph exactly once. + """ + for node in nodes: + for source in node.sources: + source_node = self._get_node(source) + if source_node is None: + raise ValueError("Could not find source of node") + source_node.destinations.append(node) + + def calculate_values( + self, + options: calculate_options.Options, + feed_dict: Optional[Dict[expression.Expression, prensor.Prensor]] = None, + ) -> None: + for node in self.ordered_node_list: + source_values = [self._node[id(x)].value for x in node.sources] + side_info = ( + feed_dict[node.expression] + if feed_dict and (node.expression in feed_dict) + else None + ) + node.calculate(source_values, options, side_info=side_info) + + def get_expressions_needed(self) -> Sequence[expression.Expression]: + return [x.expression for x in self.ordered_node_list] + + def __str__(self): + return str([str(x) for x in self.ordered_node_list]) - def __init__(self, original: ExpressionGraph): - super().__init__() - # Nodes indexed by _ExpressionNode. - self._node_map = {} # type: Dict[_ExpressionNode, _ExpressionNode] - self._add_expressions([x.expression for x in original.ordered_node_list]) - self._find_destinations(self.ordered_node_list) - def _add_expressions(self, - expressions: Sequence[expression.Expression]) -> None: - """Add expressions to the graph. +class OriginalExpressionGraph(ExpressionGraph): + """A graph representing the computation of a list of expressions. - Args: - expressions: an ordered list, where sources are before expressions, and - for all x in expressions, x == _get_earliest_equal_calculation(x) + This can directly consume a list of expressions, and create a graph. + In terms of calculating the values, the primary purpose here is to enumerate + all expressions. """ - for expr in expressions: - maybe_node = self._create_node_if_not_exists(expr) - if maybe_node is not None: - self._ordered_node_list.append(maybe_node) - def _create_node_if_not_exists(self, expr: expression.Expression - ) -> Optional[_ExpressionNode]: - """Creates a canonical node for an expression if none exists. 
+ def __init__(self, expressions: Sequence[expression.Expression]): + super().__init__() + self._add_expressions(expressions) + original_nodes = list(self._node.values()) + self._find_destinations(original_nodes) + self._order_nodes() + + def _add_expressions(self, expressions: Sequence[expression.Expression]) -> None: + """Add expressions to the graph.""" + if self.ordered_node_list: + raise ValueError("Only call once during construction") + to_add = list(expressions) + while to_add: + expr = to_add.pop() + if self._get_node(expr) is None: + earliest_expr = _get_earliest_equal_calculation(expr) + node = _ExpressionNode(earliest_expr) + self._node[id(earliest_expr)] = node + to_add.extend(node.sources) + + def _order_nodes(self) -> None: + """Topologically sorts the nodes and puts them in ordered_node_list.""" + nodes_to_process = [] # type: List[_ExpressionNode] + if self.ordered_node_list: + raise ValueError("Only call once during construction") + expr_id_to_count = { + expr_id: len(node.sources) for expr_id, node in self._node.items() + } + for k, v in sorted(expr_id_to_count.items()): + if v == 0: + nodes_to_process.append(self._node[k]) + while nodes_to_process: + node = nodes_to_process.pop() + self._ordered_node_list.append(node) + for dest in node.destinations: + count = expr_id_to_count[id(dest.expression)] - 1 + expr_id_to_count[id(dest.expression)] = count + if count == 0: + nodes_to_process.append(dest) - This method assumes that the method has already been called on all - pre-existing nodes. - After this method is called, self._get_canonical_or_error(expr) will - succeed. +class CanonicalExpressionGraph(ExpressionGraph): + """A graph representing the computation of a list of expressions. - Args: - expr: an expression to be canonicalized. + In the construction of the graph, if two expressions have inputs which can be + proven to have equal outputs, and the two expressions have equal calculations, + then they are represented by a single "canonical" node. - Returns: - a new node, or None if the new node is not created. + This canonicalization is done sequentially through the ordered_node_list of + the input graph (to avoid Python recursion limits). """ - if self._get_node(expr) is not None: - return None - maybe_canonical = _ExpressionNode(expr) - # The sources are already the earliest equal calculation. - maybe_canonical.sources = [ - self._get_canonical_or_error(x) for x in maybe_canonical.sources - ] - if maybe_canonical in self._node_map: - self._node[id(expr)] = self._node_map[maybe_canonical] - return None - self._node_map[maybe_canonical] = maybe_canonical - self._node[id(expr)] = maybe_canonical - return maybe_canonical - - def _get_canonical_or_error(self, expr: expression.Expression - ) -> expression.Expression: - """Gets a canonical expression or dies.""" - node = self._get_node(expr) - if node is not None: - return node.expression - raise ValueError("Expression not found: " + str(expr)) + + def __init__(self, original: ExpressionGraph): + super().__init__() + # Nodes indexed by _ExpressionNode. + self._node_map = {} # type: Dict[_ExpressionNode, _ExpressionNode] + self._add_expressions([x.expression for x in original.ordered_node_list]) + self._find_destinations(self.ordered_node_list) + + def _add_expressions(self, expressions: Sequence[expression.Expression]) -> None: + """Add expressions to the graph. 
+ + Args: + expressions: an ordered list, where sources are before expressions, and + for all x in expressions, x == _get_earliest_equal_calculation(x) + """ + for expr in expressions: + maybe_node = self._create_node_if_not_exists(expr) + if maybe_node is not None: + self._ordered_node_list.append(maybe_node) + + def _create_node_if_not_exists( + self, expr: expression.Expression + ) -> Optional[_ExpressionNode]: + """Creates a canonical node for an expression if none exists. + + This method assumes that the method has already been called on all + pre-existing nodes. + + After this method is called, self._get_canonical_or_error(expr) will + succeed. + + Args: + expr: an expression to be canonicalized. + + Returns: + a new node, or None if the new node is not created. + """ + if self._get_node(expr) is not None: + return None + maybe_canonical = _ExpressionNode(expr) + # The sources are already the earliest equal calculation. + maybe_canonical.sources = [ + self._get_canonical_or_error(x) for x in maybe_canonical.sources + ] + if maybe_canonical in self._node_map: + self._node[id(expr)] = self._node_map[maybe_canonical] + return None + self._node_map[maybe_canonical] = maybe_canonical + self._node[id(expr)] = maybe_canonical + return maybe_canonical + + def _get_canonical_or_error( + self, expr: expression.Expression + ) -> expression.Expression: + """Gets a canonical expression or dies.""" + node = self._get_node(expr) + if node is not None: + return node.expression + raise ValueError("Expression not found: " + str(expr)) diff --git a/struct2tensor/calculate_options.py b/struct2tensor/calculate_options.py index e8b311e..c06ee75 100644 --- a/struct2tensor/calculate_options.py +++ b/struct2tensor/calculate_options.py @@ -20,47 +20,52 @@ class Options: - """Options for calculate functions. + """Options for calculate functions. - Do not construct Options directly. The preferred method of creating an object - is calling get_default_options() or get_options_with_minimal_checks() below. - Any fine-tuning can be done by modifying the properties of the Options object - after creation. + Do not construct Options directly. The preferred method of creating an object + is calling get_default_options() or get_options_with_minimal_checks() below. + Any fine-tuning can be done by modifying the properties of the Options object + after creation. - When a method takes an optional Options object but none is provided, it will - replace it with get_default_options() . + When a method takes an optional Options object but none is provided, it will + replace it with get_default_options() . - Available options: - ragged_checks: if True, add assertion ops when converting a Prensor object - to RaggedTensors. - sparse_checks: if True, add assertion ops when converting a Prensor object - to SparseTensors. - use_string_view: if True, decode sub-messages into string views to avoid - copying. - experimental_honor_proto3_optional_semantics: if True, if a proto3 primitive - optional field without the presence semantic (i.e. the field is without - the "optional" or "repeated" label) is requested to be parsed, it will - always have a value for each input parent message. If a value is not - present on wire, the default value (0 or "") will be used. - """ + Available options: + ragged_checks: if True, add assertion ops when converting a Prensor object + to RaggedTensors. + sparse_checks: if True, add assertion ops when converting a Prensor object + to SparseTensors. 
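
The canonicalization implemented above is what lets duplicated subexpressions be computed once. A toy illustration, again leaning on the test helper used by this patch's tests:

    from struct2tensor import calculate, create_expression
    from struct2tensor.test import prensor_test_util

    expr = create_expression.create_expression_from_prensor(
        prensor_test_util.create_simple_prensor()
    )
    # Listing the same expression twice still produces a single canonical
    # node, so the underlying value is computed once and shared.
    values, graph = calculate.calculate_values_with_graph([expr, expr])
    assert values[0] is values[1]
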
diff --git a/struct2tensor/calculate_options.py b/struct2tensor/calculate_options.py
index e8b311e..c06ee75 100644
--- a/struct2tensor/calculate_options.py
+++ b/struct2tensor/calculate_options.py
@@ -20,47 +20,52 @@


 class Options:
-  """Options for calculate functions.
-
-  Do not construct Options directly. The preferred method of creating an object
-  is calling get_default_options() or get_options_with_minimal_checks() below.
-  Any fine-tuning can be done by modifying the properties of the Options object
-  after creation.
-
-  When a method takes an optional Options object but none is provided, it will
-  replace it with get_default_options() .
-
-  Available options:
-    ragged_checks: if True, add assertion ops when converting a Prensor object
-      to RaggedTensors.
-    sparse_checks: if True, add assertion ops when converting a Prensor object
-      to SparseTensors.
-    use_string_view: if True, decode sub-messages into string views to avoid
-      copying.
-    experimental_honor_proto3_optional_semantics: if True, if a proto3 primitive
-      optional field without the presence semantic (i.e. the field is without
-      the "optional" or "repeated" label) is requested to be parsed, it will
-      always have a value for each input parent message. If a value is not
-      present on wire, the default value (0 or "") will be used.
-  """
-
-  def __init__(self, ragged_checks: bool, sparse_checks: bool):
-    """Create options."""
-    self.ragged_checks = ragged_checks
-    self.sparse_checks = sparse_checks
-    self.use_string_view = False
-    self.experimental_honor_proto3_optional_semantics = False
-
-  def __str__(self):
-    return ("{ragged_checks:" + str(self.ragged_checks) + ", sparse_checks: " +
-            str(self.sparse_checks) + "}")
+    """Options for calculate functions.
+
+    Do not construct Options directly. The preferred method of creating an object
+    is calling get_default_options() or get_options_with_minimal_checks() below.
+    Any fine-tuning can be done by modifying the properties of the Options object
+    after creation.
+
+    When a method takes an optional Options object but none is provided, it will
+    replace it with get_default_options().
+
+    Available options:
+      ragged_checks: if True, add assertion ops when converting a Prensor object
+        to RaggedTensors.
+      sparse_checks: if True, add assertion ops when converting a Prensor object
+        to SparseTensors.
+      use_string_view: if True, decode sub-messages into string views to avoid
+        copying.
+      experimental_honor_proto3_optional_semantics: if True, if a proto3 primitive
+        optional field without the presence semantic (i.e. the field is without
+        the "optional" or "repeated" label) is requested to be parsed, it will
+        always have a value for each input parent message. If a value is not
+        present on the wire, the default value (0 or "") will be used.
+    """
+
+    def __init__(self, ragged_checks: bool, sparse_checks: bool):
+        """Create options."""
+        self.ragged_checks = ragged_checks
+        self.sparse_checks = sparse_checks
+        self.use_string_view = False
+        self.experimental_honor_proto3_optional_semantics = False
+
+    def __str__(self):
+        return (
+            "{ragged_checks:"
+            + str(self.ragged_checks)
+            + ", sparse_checks: "
+            + str(self.sparse_checks)
+            + "}"
+        )


 def get_default_options() -> Options:
-  """Get the default options."""
-  return Options(ragged_checks=True, sparse_checks=True)
+    """Get the default options."""
+    return Options(ragged_checks=True, sparse_checks=True)


 def get_options_with_minimal_checks() -> Options:
-  """Options for calculation with minimal runtime checks."""
-  return Options(ragged_checks=False, sparse_checks=False)
+    """Options for calculation with minimal runtime checks."""
+    return Options(ragged_checks=False, sparse_checks=False)
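
As the docstring above prescribes, callers start from a factory function and then adjust individual properties; for example:

    from struct2tensor import calculate_options

    # Preferred construction path; direct construction of Options is discouraged.
    options = calculate_options.get_default_options()
    options.sparse_checks = False
    options.use_string_view = True
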
diff --git a/struct2tensor/calculate_test.py b/struct2tensor/calculate_test.py
index 1a2b5f5..d4a0914 100644
--- a/struct2tensor/calculate_test.py
+++ b/struct2tensor/calculate_test.py
@@ -18,183 +18,199 @@
 import tensorflow as tf
 from absl.testing import absltest
 from tensorflow.python.framework import (
-    test_util,  # pylint: disable=g-direct-tensorflow-import
+    test_util,  # pylint: disable=g-direct-tensorflow-import
 )

 from struct2tensor import (
-    calculate,
-    calculate_options,
-    create_expression,
-    expression_add,
-    path,
+    calculate,
+    calculate_options,
+    create_expression,
+    expression_add,
+    path,
 )
 from struct2tensor.expression_impl import promote, proto_test_util
 from struct2tensor.test import expression_test_util, prensor_test_util


 def get_mock_linear_graph(length):
-  result = [expression_test_util.get_mock_leaf(True, tf.int64, name="Node0")]
-  for i in range(1, length):
-    result.append(
-        expression_test_util.get_mock_leaf(
-            False,
-            tf.int64,
-            name="Node" + str(i),
-            source_expressions=[result[i - 1]]))
-  return result[-1]
+    result = [expression_test_util.get_mock_leaf(True, tf.int64, name="Node0")]
+    for i in range(1, length):
+        result.append(
+            expression_test_util.get_mock_leaf(
+                False,
+                tf.int64,
+                name="Node" + str(i),
+                source_expressions=[result[i - 1]],
+            )
+        )
+    return result[-1]


 def create_random_graph_helper(size, density, fraction_identity):
-  partial = []
-  if size:
-    partial = create_random_graph_helper(size - 1, density, fraction_identity)
-  if random.random() < fraction_identity and partial:
-    source_expressions = [partial[random.randint(0, len(partial) - 1)]]
-    partial.append(
-        expression_test_util.get_mock_leaf(
-            False,
-            tf.int64,
-            name="Node" + str(size - 1),
-            source_expressions=source_expressions,
-            calculate_is_identity=True))
-    return partial
-
-  source_expressions = []
-  for x in partial:
-    if random.random() < density:
-      source_expressions.append(x)
-  partial.append(
-      expression_test_util.get_mock_leaf(
-          False,
-          tf.int64,
-          name="Node" + str(size - 1),
-          source_expressions=source_expressions))
-  return partial
+    partial = []
+    if size:
+        partial = create_random_graph_helper(size - 1, density, fraction_identity)
+    if random.random() < fraction_identity and partial:
+        source_expressions = [partial[random.randint(0, len(partial) - 1)]]
+        partial.append(
+            expression_test_util.get_mock_leaf(
+                False,
+                tf.int64,
+                name="Node" + str(size - 1),
+                source_expressions=source_expressions,
+                calculate_is_identity=True,
+            )
+        )
+        return partial
+
+    source_expressions = []
+    for x in partial:
+        if random.random() < density:
+            source_expressions.append(x)
+    partial.append(
+        expression_test_util.get_mock_leaf(
+            False,
+            tf.int64,
+            name="Node" + str(size - 1),
+            source_expressions=source_expressions,
+        )
+    )
+    return partial


 def create_random_graph(size, density, fraction_identity):
-  return create_random_graph_helper(size, density, fraction_identity)[-1]
+    return create_random_graph_helper(size, density, fraction_identity)[-1]


 options_to_test = [
     calculate_options.get_default_options(),
-    calculate_options.get_options_with_minimal_checks(), None
+    calculate_options.get_options_with_minimal_checks(),
+    None,
 ]


 @test_util.run_all_in_graph_and_eager_modes
 class CalculateTest(tf.test.TestCase):
-
-  def test_calculate_mock(self):
-    my_input = get_mock_linear_graph(5)
-    [result] = calculate.calculate_values([my_input])
-    self.assertIs(my_input._calculate_output, result)
-
-  def test_calculate_broken_mock_is_repeated(self):
-    with self.assertRaisesRegex(
-        ValueError,
-        "Expression Node0 returned the wrong type: expected: repeated actual: optional .",
-    ):
-      single_node = expression_test_util.get_mock_broken_leaf(
-          True, tf.int64, False, tf.int64, name="Node0")
-      calculate.calculate_values([single_node])
-
-  def test_calculate_broken_mock_dtype(self):
-    with self.assertRaisesRegex(
-        ValueError,
-        "Expression Node0 returned the "
-        "wrong type: expected: repeated actual: repeated .",
-    ):
-      single_node = expression_test_util.get_mock_broken_leaf(
-          True, tf.int64, True, tf.int32, name="Node0")
-      calculate.calculate_values([single_node])
-
-  def test_calculate_random_mock(self):
-    """Test calculate on 1000 graphs with 20 nodes.
-
-    This will have identity operations.
-    This will not have equal operations (outside of identity operations).
-    All nodes in the graph are leaf expressions.
-    """
-    random.seed(a=12345)
-    for options in options_to_test:
-      for _ in range(1000):
-        my_input = create_random_graph(20, 0.5, 0.2)
-        [result] = calculate.calculate_values([my_input], options=options)
-        self.assertIs(my_input._calculate_output, result)
-
-  def test_calculate_root_direct(self):
-    """Calculates the value of a node with no sources."""
-    for options in options_to_test:
-      tree = create_expression.create_expression_from_prensor(
-          prensor_test_util.create_simple_prensor())
-      [root_value] = calculate.calculate_values([tree], options=options)
-      self.assertAllEqual(root_value.size, 3)
-
-  def test_calculate_root_indirect(self):
-    """Calculates the value of a node with one source."""
-    for options in options_to_test:
-      tree = create_expression.create_expression_from_prensor(
-          prensor_test_util.create_simple_prensor())
-      tree_2 = expression_add.add_paths(tree, {})
-      [root_value] = calculate.calculate_values([tree_2], options=options)
-      self.assertAllEqual(root_value.size, 3)
-
-  def test_calculate_tree_root_direct(self):
-    """Calculates the value of a tree with no sources."""
-    for options in options_to_test:
-      tree = create_expression.create_expression_from_prensor(
-          prensor_test_util.create_simple_prensor())
-      [new_expr] = calculate.calculate_prensors([tree], options=options)
-      self.assertAllEqual(new_expr.node.size, 3)
-
-  def test_calculate_promote_anonymous(self):
-    """Performs promote_test.PromoteValuesTest, but with calculate_values."""
-    for options in options_to_test:
-      expr = create_expression.create_expression_from_prensor(
-          prensor_test_util.create_nested_prensor())
-      new_root, new_path = promote.promote_anonymous(
-          expr, path.Path(["user", "friends"]))
-      new_field = new_root.get_descendant_or_error(new_path)
-      [leaf_node] = calculate.calculate_values([new_field], options=options)
-      self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 1, 2])
-      self.assertAllEqual(leaf_node.values, [b"a", b"b", b"c", b"d", b"e"])
-
-  def test_calculate_promote_named(self):
-    """Performs promote_test.PromoteValuesTest, but with calculate_values."""
-    for options in options_to_test:
-      expr = create_expression.create_expression_from_prensor(
-          prensor_test_util.create_nested_prensor())
-      new_root = promote.promote(expr, path.Path(["user", "friends"]),
-                                 "new_friends")
-      # projected = project.project(new_root, [path.Path(["new_friends"])])
-      new_field = new_root.get_child_or_error("new_friends")
-      [leaf_node] = calculate.calculate_values([new_field], options=options)
-      self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 1, 2])
-      self.assertAllEqual(leaf_node.values, [b"a", b"b", b"c", b"d", b"e"])
-
-  def test_create_query_and_calculate_event_value(self):
-    """Calculating a child value in a proto tests dependencies."""
-    for options in options_to_test:
-      expr = proto_test_util._get_expression_from_session_empty_user_info()
-      [event_value
-      ] = calculate.calculate_values([expr.get_child_or_error("event")],
-                                     options=options)
-      self.assertAllEqual(event_value.parent_index, [0, 0, 0, 1, 1])
-
-  def test_create_query_modify_and_calculate_event_value(self):
-    """Calculating a child value in a proto tests dependencies."""
-    for options in options_to_test:
-      root = proto_test_util._get_expression_from_session_empty_user_info()
-      root_2 = expression_add.add_paths(
-          root, {path.Path(["event_copy"]): root.get_child_or_error("event")})
-      [event_value
-      ] = calculate.calculate_values([root_2.get_child_or_error("event_copy")],
-                                     options=options)
-      self.assertAllEqual(event_value.parent_index, [0, 0, 0, 1, 1])
+    def test_calculate_mock(self):
+        my_input = get_mock_linear_graph(5)
+        [result] = calculate.calculate_values([my_input])
+        self.assertIs(my_input._calculate_output, result)
+
+    def test_calculate_broken_mock_is_repeated(self):
+        with self.assertRaisesRegex(
+            ValueError,
+            "Expression Node0 returned the wrong type: expected: repeated actual: optional .",
+        ):
+            single_node = expression_test_util.get_mock_broken_leaf(
+                True, tf.int64, False, tf.int64, name="Node0"
+            )
+            calculate.calculate_values([single_node])
+
+    def test_calculate_broken_mock_dtype(self):
+        with self.assertRaisesRegex(
+            ValueError,
+            "Expression Node0 returned the "
+            "wrong type: expected: repeated actual: repeated .",
+        ):
+            single_node = expression_test_util.get_mock_broken_leaf(
+                True, tf.int64, True, tf.int32, name="Node0"
+            )
+            calculate.calculate_values([single_node])
+
+    def test_calculate_random_mock(self):
+        """Test calculate on 1000 graphs with 20 nodes.
+
+        This will have identity operations.
+        This will not have equal operations (outside of identity operations).
+        All nodes in the graph are leaf expressions.
+        """
+        random.seed(a=12345)
+        for options in options_to_test:
+            for _ in range(1000):
+                my_input = create_random_graph(20, 0.5, 0.2)
+                [result] = calculate.calculate_values([my_input], options=options)
+                self.assertIs(my_input._calculate_output, result)
+
+    def test_calculate_root_direct(self):
+        """Calculates the value of a node with no sources."""
+        for options in options_to_test:
+            tree = create_expression.create_expression_from_prensor(
+                prensor_test_util.create_simple_prensor()
+            )
+            [root_value] = calculate.calculate_values([tree], options=options)
+            self.assertAllEqual(root_value.size, 3)
+
+    def test_calculate_root_indirect(self):
+        """Calculates the value of a node with one source."""
+        for options in options_to_test:
+            tree = create_expression.create_expression_from_prensor(
+                prensor_test_util.create_simple_prensor()
+            )
+            tree_2 = expression_add.add_paths(tree, {})
+            [root_value] = calculate.calculate_values([tree_2], options=options)
+            self.assertAllEqual(root_value.size, 3)
+
+    def test_calculate_tree_root_direct(self):
+        """Calculates the value of a tree with no sources."""
+        for options in options_to_test:
+            tree = create_expression.create_expression_from_prensor(
+                prensor_test_util.create_simple_prensor()
+            )
+            [new_expr] = calculate.calculate_prensors([tree], options=options)
+            self.assertAllEqual(new_expr.node.size, 3)
+
+    def test_calculate_promote_anonymous(self):
+        """Performs promote_test.PromoteValuesTest, but with calculate_values."""
+        for options in options_to_test:
+            expr = create_expression.create_expression_from_prensor(
+                prensor_test_util.create_nested_prensor()
+            )
+            new_root, new_path = promote.promote_anonymous(
+                expr, path.Path(["user", "friends"])
+            )
+            new_field = new_root.get_descendant_or_error(new_path)
+            [leaf_node] = calculate.calculate_values([new_field], options=options)
+            self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 1, 2])
+            self.assertAllEqual(leaf_node.values, [b"a", b"b", b"c", b"d", b"e"])
+
+    def test_calculate_promote_named(self):
+        """Performs promote_test.PromoteValuesTest, but with calculate_values."""
+        for options in options_to_test:
+            expr = create_expression.create_expression_from_prensor(
+                prensor_test_util.create_nested_prensor()
+            )
+            new_root = promote.promote(
+                expr, path.Path(["user", "friends"]), "new_friends"
+            )
+            # projected = project.project(new_root, [path.Path(["new_friends"])])
+            new_field = new_root.get_child_or_error("new_friends")
+
[leaf_node] = calculate.calculate_values([new_field], options=options) + self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 1, 2]) + self.assertAllEqual(leaf_node.values, [b"a", b"b", b"c", b"d", b"e"]) + + def test_create_query_and_calculate_event_value(self): + """Calculating a child value in a proto tests dependencies.""" + for options in options_to_test: + expr = proto_test_util._get_expression_from_session_empty_user_info() + [event_value] = calculate.calculate_values( + [expr.get_child_or_error("event")], options=options + ) + self.assertAllEqual(event_value.parent_index, [0, 0, 0, 1, 1]) + + def test_create_query_modify_and_calculate_event_value(self): + """Calculating a child value in a proto tests dependencies.""" + for options in options_to_test: + root = proto_test_util._get_expression_from_session_empty_user_info() + root_2 = expression_add.add_paths( + root, {path.Path(["event_copy"]): root.get_child_or_error("event")} + ) + [event_value] = calculate.calculate_values( + [root_2.get_child_or_error("event_copy")], options=options + ) + self.assertAllEqual(event_value.parent_index, [0, 0, 0, 1, 1]) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/calculate_with_source_paths.py b/struct2tensor/calculate_with_source_paths.py index 02bf714..247be63 100644 --- a/struct2tensor/calculate_with_source_paths.py +++ b/struct2tensor/calculate_with_source_paths.py @@ -29,98 +29,104 @@ # descriptor is descriptor.Descriptor # paths is List[path.Path] class ProtoRequirements(NamedTuple): - tensor: tf.Tensor - descriptor: descriptor.Descriptor - paths: List[path.Path] + tensor: tf.Tensor + descriptor: descriptor.Descriptor + paths: List[path.Path] def calculate_prensors_with_source_paths( trees: Sequence[expression.Expression], - options: Optional[calculate_options.Options] = None + options: Optional[calculate_options.Options] = None, ) -> Tuple[Sequence[prensor.Prensor], Sequence[ProtoRequirements]]: - """Returns a list of prensor trees, and proto summaries.""" - prensors, graph = calculate.calculate_prensors_with_graph( - trees, options=options) - proto_expressions = [ - x for x in graph.get_expressions_needed() if proto.is_proto_expression(x) - ] - summaries = _get_proto_summaries(proto_expressions) - return prensors, summaries + """Returns a list of prensor trees, and proto summaries.""" + prensors, graph = calculate.calculate_prensors_with_graph(trees, options=options) + proto_expressions = [ + x for x in graph.get_expressions_needed() if proto.is_proto_expression(x) + ] + summaries = _get_proto_summaries(proto_expressions) + return prensors, summaries def _dedup_paths(paths: Sequence[path.Path]) -> List[path.Path]: - """Deduplicates paths including prefixes. - - Args: - paths: a list of paths to dedup. - - Returns: - A new list l where: - each x in l is in paths - if x in l, then no prefix of x is in l (except x itself). - if x in paths, then there exists a y in l where x is a prefix of y. - """ - # Sort new_paths in increasing length. - # If x is a strict prefix of y, then x < y. - new_paths = sorted(paths, key=len) - new_paths.reverse() - # If x is a strict prefix of y, now x is after y. - result = set() # type: Set[path.Path] - result_and_prefixes = set() - for p in new_paths: - if p not in result_and_prefixes: - result.add(p) - for sub_path in [p.prefix(i) for i in range(len(p) + 1)]: - result_and_prefixes.add(sub_path) - - return list(result) + """Deduplicates paths including prefixes. + + Args: + paths: a list of paths to dedup. 
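+
+    For illustration (a hypothetical call mirroring test_dedup_paths in the
+    test file below):
+        _dedup_paths([path.Path(["a", "b"]), path.Path(["a"])])
+    returns [path.Path(["a", "b"])], since Path(["a"]) is a prefix of
+    Path(["a", "b"]) and is therefore dropped.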
+ + Returns: + A new list l where: + each x in l is in paths + if x in l, then no prefix of x is in l (except x itself). + if x in paths, then there exists a y in l where x is a prefix of y. + """ + # Sort new_paths in increasing length. + # If x is a strict prefix of y, then x < y. + new_paths = sorted(paths, key=len) + new_paths.reverse() + # If x is a strict prefix of y, now x is after y. + result = set() # type: Set[path.Path] + result_and_prefixes = set() + for p in new_paths: + if p not in result_and_prefixes: + result.add(p) + for sub_path in [p.prefix(i) for i in range(len(p) + 1)]: + result_and_prefixes.add(sub_path) + + return list(result) def requirements_to_metadata_proto( - inp: Sequence[ProtoRequirements], - result: query_metadata_pb2.QueryMetadata) -> None: - """Populates result with a proto representation of the ProtoRequirements. + inp: Sequence[ProtoRequirements], result: query_metadata_pb2.QueryMetadata +) -> None: + """Populates result with a proto representation of the ProtoRequirements. - This drops the link to the individual tensor of protos objects, which cannot - be meaningfully serialized. If there are two protos of the same type, there - will be two parsed_proto_info submessages with the same message_name. + This drops the link to the individual tensor of protos objects, which cannot + be meaningfully serialized. If there are two protos of the same type, there + will be two parsed_proto_info submessages with the same message_name. - Args: - inp: Sequence[ProtoRequirements] - result: a proto to be populated. - """ - for x in inp: - _requirement_to_parsed_proto_info(x, result.parsed_proto_info.add()) # pytype: disable=wrong-arg-types + Args: + inp: Sequence[ProtoRequirements] + result: a proto to be populated. + """ + for x in inp: + _requirement_to_parsed_proto_info( + x, result.parsed_proto_info.add() + ) # pytype: disable=wrong-arg-types def _get_proto_summaries( - proto_expressions: Sequence[proto.ProtoExpression] + proto_expressions: Sequence[proto.ProtoExpression], ) -> Sequence[ProtoRequirements]: - """Gets the proto summaries.""" - result = [] # type: List[ProtoRequirements] - for expr in proto_expressions: - - def get_summary(tensor_of_protos, desc): - for summary in result: - if id(summary.tensor) == id( - tensor_of_protos) and summary.descriptor == desc: - return summary - new_summary = ProtoRequirements( - tensor=tensor_of_protos, descriptor=desc, paths=[]) - result.append(new_summary) - return new_summary - - tensor_of_protos, desc = expr.get_proto_source() - get_summary(tensor_of_protos, desc).paths.append(expr.get_path()) - return [ - ProtoRequirements( - tensor=x.tensor, descriptor=x.descriptor, paths=_dedup_paths(x.paths)) - for x in result - ] + """Gets the proto summaries.""" + result = [] # type: List[ProtoRequirements] + for expr in proto_expressions: + + def get_summary(tensor_of_protos, desc): + for summary in result: + if ( + id(summary.tensor) == id(tensor_of_protos) + and summary.descriptor == desc + ): + return summary + new_summary = ProtoRequirements( + tensor=tensor_of_protos, descriptor=desc, paths=[] + ) + result.append(new_summary) + return new_summary + + tensor_of_protos, desc = expr.get_proto_source() + get_summary(tensor_of_protos, desc).paths.append(expr.get_path()) + return [ + ProtoRequirements( + tensor=x.tensor, descriptor=x.descriptor, paths=_dedup_paths(x.paths) + ) + for x in result + ] def _requirement_to_parsed_proto_info( - inp: ProtoRequirements, result: query_metadata_pb2.ParsedProtoInfo) -> None: - 
result.message_name = inp.descriptor.full_name - for p in inp.paths: - result.field_paths.add().MergeFrom(p.as_proto()) + inp: ProtoRequirements, result: query_metadata_pb2.ParsedProtoInfo +) -> None: + result.message_name = inp.descriptor.full_name + for p in inp.paths: + result.field_paths.add().MergeFrom(p.as_proto()) diff --git a/struct2tensor/calculate_with_source_paths_test.py b/struct2tensor/calculate_with_source_paths_test.py index 09feaca..af14e93 100644 --- a/struct2tensor/calculate_with_source_paths_test.py +++ b/struct2tensor/calculate_with_source_paths_test.py @@ -17,7 +17,7 @@ from absl.testing import absltest from google.protobuf import text_format from tensorflow.python.framework import ( - test_util, # pylint: disable=g-direct-tensorflow-import + test_util, # pylint: disable=g-direct-tensorflow-import ) from struct2tensor import calculate_with_source_paths, path @@ -28,125 +28,145 @@ @test_util.run_all_in_graph_and_eager_modes class CalculateWithSourcePathsTest(tf.test.TestCase): - - def assertLen(self, arr, expected_len): - self.assertEqual(len(arr), expected_len) - - # Doesn't check number of occurrences. - # E.g.: ["a", "a", "b"] == ["a", "b"] - def equal_ignore_order(self, a, b): - debug_string = "[{}] vs [{}]".format( - ",".join([f'"{str(ae)}"' for ae in a]), - ",".join([f'"{str(be)}"' for be in b])) - for ae in a: - self.assertIn(ae, b, debug_string) - for be in b: - self.assertIn(be, a, debug_string) - - def test_dedup_paths(self): - paths = [path.Path(["a", "b"]), path.Path(["a"])] - result = calculate_with_source_paths._dedup_paths(paths) - self.equal_ignore_order([path.Path(["a", "b"])], result) - - def test_calculate_prensors_with_source_paths(self): - """Tests get_sparse_tensors on a deep tree.""" - expr = proto_test_util._get_expression_from_session_empty_user_info() - # Let's make it non-trivial by transforming the data. - new_root = promote.promote(expr, path.Path(["event", "action", "doc_id"]), - "action_doc_ids") - # A poor-man's reroot. - new_field = new_root.get_descendant_or_error( - path.Path(["event", "action_doc_ids"])) - result = calculate_with_source_paths.calculate_prensors_with_source_paths( - [new_field]) - prensor_result, proto_summary_result = result - self.assertLen(prensor_result, 1) - self.assertLen(proto_summary_result, 1) - leaf_node = prensor_result[0].node - self.assertAllEqual(leaf_node.parent_index, [0, 0, 1, 2, 2, 3, 4, 4, 4]) - self.assertAllEqual(leaf_node.values, - [b"a", b"b", b"c", b"e", b"f", b"g", b"h", b"i", b"j"]) - list_of_paths = proto_summary_result[0].paths - expected = [path.Path(["event", "action", "doc_id"])] - self.equal_ignore_order(list_of_paths, expected) - - def test_calculate_prensors_with_source_paths_with_transform(self): - """Tests get_sparse_tensors on a deep tree with a transformed field.""" - expr = proto_test_util._get_expression_from_session_empty_user_info() - - # Let's make it non-trivial by transforming the data. - def _reverse(parent_indices, values): - return parent_indices, tf.reverse(values, axis=[-1]) - - expr = proto.create_transformed_field(expr, path.Path(["event"]), - "reversed_event", _reverse) - new_root = promote.promote( - expr, path.Path(["reversed_event", "action", "doc_id"]), - "action_doc_ids") - # A poor-man's reroot. 
- new_field = new_root.get_descendant_or_error( - path.Path(["reversed_event", "action_doc_ids"])) - result = calculate_with_source_paths.calculate_prensors_with_source_paths( - [new_field]) - prensor_result, proto_summary_result = result - self.assertLen(prensor_result, 1) - self.assertLen(proto_summary_result, 1) - leaf_node = prensor_result[0].node - self.assertAllEqual(leaf_node.parent_index, [0, 0, 0, 1, 2, 2, 3, 4, 4]) - self.assertAllEqual(leaf_node.values, - [b"h", b"i", b"j", b"g", b"e", b"f", b"c", b"a", b"b"]) - list_of_paths = proto_summary_result[0].paths - expected = [path.Path(["event", "action", "doc_id"])] - self.equal_ignore_order(list_of_paths, expected) - - def test_calculate_prensors_with_source_paths_with_multiple_transforms(self): - """Tests get_sparse_tensors on a deep tree with a transformed field.""" - expr = proto_test_util._get_expression_from_session_empty_user_info() - - # Let's make it non-trivial by transforming the data. - def _reverse(parent_indices, values): - return parent_indices, tf.reverse(values, axis=[-1]) - - expr = proto.create_transformed_field(expr, path.Path(["event"]), - "reversed_event", _reverse) - expr = proto.create_transformed_field(expr, path.Path(["reversed_event"]), - "reversed_reversed_event", _reverse) - new_root = promote.promote( - expr, path.Path(["reversed_reversed_event", "action", "doc_id"]), - "action_doc_ids") - # A poor-man's reroot. - new_field = new_root.get_descendant_or_error( - path.Path(["reversed_reversed_event", "action_doc_ids"])) - result = calculate_with_source_paths.calculate_prensors_with_source_paths( - [new_field]) - prensor_result, proto_summary_result = result - self.assertLen(prensor_result, 1) - self.assertLen(proto_summary_result, 1) - leaf_node = prensor_result[0].node - self.assertAllEqual(leaf_node.parent_index, [0, 0, 1, 2, 2, 3, 4, 4, 4]) - self.assertAllEqual(leaf_node.values, - [b"a", b"b", b"c", b"e", b"f", b"g", b"h", b"i", b"j"]) - list_of_paths = proto_summary_result[0].paths - expected = [path.Path(["event", "action", "doc_id"])] - self.equal_ignore_order(list_of_paths, expected) - - def test_requirements_to_metadata_proto(self): - proto_summary_result_0 = calculate_with_source_paths.ProtoRequirements( - None, test_pb2.Session.DESCRIPTOR, [ - path.Path(["event", "action", "doc_id"]), - path.Path(["event", "event_id"]) - ]) - proto_summary_result_1 = calculate_with_source_paths.ProtoRequirements( - None, test_pb2.Event.DESCRIPTOR, - [path.Path(["action", "doc_id"]), - path.Path(["event_id"])]) - - result = query_metadata_pb2.QueryMetadata() - calculate_with_source_paths.requirements_to_metadata_proto( - [proto_summary_result_0, proto_summary_result_1], result) - self.assertLen(result.parsed_proto_info, 2) - expected_result = text_format.Parse( - """ + def assertLen(self, arr, expected_len): + self.assertEqual(len(arr), expected_len) + + # Doesn't check number of occurrences. 
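+    # (It only asserts that every element of each list appears in the other,
+    # so multiplicity is ignored.)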
+ # E.g.: ["a", "a", "b"] == ["a", "b"] + def equal_ignore_order(self, a, b): + debug_string = "[{}] vs [{}]".format( + ",".join([f'"{str(ae)}"' for ae in a]), + ",".join([f'"{str(be)}"' for be in b]), + ) + for ae in a: + self.assertIn(ae, b, debug_string) + for be in b: + self.assertIn(be, a, debug_string) + + def test_dedup_paths(self): + paths = [path.Path(["a", "b"]), path.Path(["a"])] + result = calculate_with_source_paths._dedup_paths(paths) + self.equal_ignore_order([path.Path(["a", "b"])], result) + + def test_calculate_prensors_with_source_paths(self): + """Tests get_sparse_tensors on a deep tree.""" + expr = proto_test_util._get_expression_from_session_empty_user_info() + # Let's make it non-trivial by transforming the data. + new_root = promote.promote( + expr, path.Path(["event", "action", "doc_id"]), "action_doc_ids" + ) + # A poor-man's reroot. + new_field = new_root.get_descendant_or_error( + path.Path(["event", "action_doc_ids"]) + ) + result = calculate_with_source_paths.calculate_prensors_with_source_paths( + [new_field] + ) + prensor_result, proto_summary_result = result + self.assertLen(prensor_result, 1) + self.assertLen(proto_summary_result, 1) + leaf_node = prensor_result[0].node + self.assertAllEqual(leaf_node.parent_index, [0, 0, 1, 2, 2, 3, 4, 4, 4]) + self.assertAllEqual( + leaf_node.values, [b"a", b"b", b"c", b"e", b"f", b"g", b"h", b"i", b"j"] + ) + list_of_paths = proto_summary_result[0].paths + expected = [path.Path(["event", "action", "doc_id"])] + self.equal_ignore_order(list_of_paths, expected) + + def test_calculate_prensors_with_source_paths_with_transform(self): + """Tests get_sparse_tensors on a deep tree with a transformed field.""" + expr = proto_test_util._get_expression_from_session_empty_user_info() + + # Let's make it non-trivial by transforming the data. + def _reverse(parent_indices, values): + return parent_indices, tf.reverse(values, axis=[-1]) + + expr = proto.create_transformed_field( + expr, path.Path(["event"]), "reversed_event", _reverse + ) + new_root = promote.promote( + expr, path.Path(["reversed_event", "action", "doc_id"]), "action_doc_ids" + ) + # A poor-man's reroot. + new_field = new_root.get_descendant_or_error( + path.Path(["reversed_event", "action_doc_ids"]) + ) + result = calculate_with_source_paths.calculate_prensors_with_source_paths( + [new_field] + ) + prensor_result, proto_summary_result = result + self.assertLen(prensor_result, 1) + self.assertLen(proto_summary_result, 1) + leaf_node = prensor_result[0].node + self.assertAllEqual(leaf_node.parent_index, [0, 0, 0, 1, 2, 2, 3, 4, 4]) + self.assertAllEqual( + leaf_node.values, [b"h", b"i", b"j", b"g", b"e", b"f", b"c", b"a", b"b"] + ) + list_of_paths = proto_summary_result[0].paths + expected = [path.Path(["event", "action", "doc_id"])] + self.equal_ignore_order(list_of_paths, expected) + + def test_calculate_prensors_with_source_paths_with_multiple_transforms(self): + """Tests get_sparse_tensors on a deep tree with a transformed field.""" + expr = proto_test_util._get_expression_from_session_empty_user_info() + + # Let's make it non-trivial by transforming the data. 
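+        # The transform takes and returns (parent_indices, values) tensors for
+        # the repeated "event" field; this one keeps the parent indices and
+        # reverses the order of the flat values.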
+ def _reverse(parent_indices, values): + return parent_indices, tf.reverse(values, axis=[-1]) + + expr = proto.create_transformed_field( + expr, path.Path(["event"]), "reversed_event", _reverse + ) + expr = proto.create_transformed_field( + expr, path.Path(["reversed_event"]), "reversed_reversed_event", _reverse + ) + new_root = promote.promote( + expr, + path.Path(["reversed_reversed_event", "action", "doc_id"]), + "action_doc_ids", + ) + # A poor-man's reroot. + new_field = new_root.get_descendant_or_error( + path.Path(["reversed_reversed_event", "action_doc_ids"]) + ) + result = calculate_with_source_paths.calculate_prensors_with_source_paths( + [new_field] + ) + prensor_result, proto_summary_result = result + self.assertLen(prensor_result, 1) + self.assertLen(proto_summary_result, 1) + leaf_node = prensor_result[0].node + self.assertAllEqual(leaf_node.parent_index, [0, 0, 1, 2, 2, 3, 4, 4, 4]) + self.assertAllEqual( + leaf_node.values, [b"a", b"b", b"c", b"e", b"f", b"g", b"h", b"i", b"j"] + ) + list_of_paths = proto_summary_result[0].paths + expected = [path.Path(["event", "action", "doc_id"])] + self.equal_ignore_order(list_of_paths, expected) + + def test_requirements_to_metadata_proto(self): + proto_summary_result_0 = calculate_with_source_paths.ProtoRequirements( + None, + test_pb2.Session.DESCRIPTOR, + [ + path.Path(["event", "action", "doc_id"]), + path.Path(["event", "event_id"]), + ], + ) + proto_summary_result_1 = calculate_with_source_paths.ProtoRequirements( + None, + test_pb2.Event.DESCRIPTOR, + [path.Path(["action", "doc_id"]), path.Path(["event_id"])], + ) + + result = query_metadata_pb2.QueryMetadata() + calculate_with_source_paths.requirements_to_metadata_proto( + [proto_summary_result_0, proto_summary_result_1], result + ) + self.assertLen(result.parsed_proto_info, 2) + expected_result = text_format.Parse( + """ parsed_proto_info { message_name: "struct2tensor.test.Session" field_paths { @@ -168,38 +188,40 @@ def test_requirements_to_metadata_proto(self): field_paths { step: "event_id" } - }""", query_metadata_pb2.QueryMetadata()) - self.assertEqual(result, expected_result) - - def test_any_path(self): - my_any_0 = test_any_pb2.MessageWithAny() - my_value_0 = test_pb2.AllSimple() - my_value_0.optional_int32 = 17 - my_any_0.my_any.Pack(my_value_0) - expr = proto.create_expression_from_proto( - [my_any_0.SerializeToString()], test_any_pb2.MessageWithAny.DESCRIPTOR) - new_root = promote.promote( - expr, - path.Path( - ["my_any", "(struct2tensor.test.AllSimple)", "optional_int32"]), - "new_int32") - new_field = new_root.get_descendant_or_error( - path.Path(["my_any", "new_int32"])) - result = calculate_with_source_paths.calculate_prensors_with_source_paths( - [new_field]) - prensor_result, proto_summary_result = result - self.assertLen(prensor_result, 1) - self.assertLen(proto_summary_result, 1) - leaf_node = prensor_result[0].node - self.assertAllEqual(leaf_node.parent_index, [0]) - self.assertAllEqual(leaf_node.values, [17]) - list_of_paths = proto_summary_result[0].paths - expected = [ - path.Path( - ["my_any", "(struct2tensor.test.AllSimple)", "optional_int32"]) - ] - self.equal_ignore_order(list_of_paths, expected) + }""", + query_metadata_pb2.QueryMetadata(), + ) + self.assertEqual(result, expected_result) + + def test_any_path(self): + my_any_0 = test_any_pb2.MessageWithAny() + my_value_0 = test_pb2.AllSimple() + my_value_0.optional_int32 = 17 + my_any_0.my_any.Pack(my_value_0) + expr = proto.create_expression_from_proto( + [my_any_0.SerializeToString()], 
test_any_pb2.MessageWithAny.DESCRIPTOR + ) + new_root = promote.promote( + expr, + path.Path(["my_any", "(struct2tensor.test.AllSimple)", "optional_int32"]), + "new_int32", + ) + new_field = new_root.get_descendant_or_error(path.Path(["my_any", "new_int32"])) + result = calculate_with_source_paths.calculate_prensors_with_source_paths( + [new_field] + ) + prensor_result, proto_summary_result = result + self.assertLen(prensor_result, 1) + self.assertLen(proto_summary_result, 1) + leaf_node = prensor_result[0].node + self.assertAllEqual(leaf_node.parent_index, [0]) + self.assertAllEqual(leaf_node.values, [17]) + list_of_paths = proto_summary_result[0].paths + expected = [ + path.Path(["my_any", "(struct2tensor.test.AllSimple)", "optional_int32"]) + ] + self.equal_ignore_order(list_of_paths, expected) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/create_expression.py b/struct2tensor/create_expression.py index 18bc9d9..d285dd0 100644 --- a/struct2tensor/create_expression.py +++ b/struct2tensor/create_expression.py @@ -25,106 +25,104 @@ class _DirectExpression(expression.Expression): - """An expression where the value is immediate. + """An expression where the value is immediate. + + This expression has no sources, and a NodeTensor is set at construction time. + """ + + def __init__( + self, + is_repeated: bool, + my_type: Optional[tf.DType], + value: prensor.NodeTensor, + children: Mapping[path.Step, expression.Expression], + validate_step_format: bool, + ): + """Initializes an expression. + + Args: + is_repeated: if the expression is repeated. + my_type: the DType of a field, or None for an internal node. + value: the return value of calculate(...) + children: the subexpressions. + validate_step_format: If True, validates that steps do not have any + characters that could be ambiguously understood as structure delimiters + (e.g. "."). If False, such characters are allowed and the client is + responsible to ensure to not rely on any auto-coercion of strings to + paths. + """ + super().__init__( + is_repeated, my_type, validate_step_format=validate_step_format + ) + self._value = value + self._children = children + + def get_source_expressions(self) -> Sequence[expression.Expression]: + return [] + + def calculate( + self, + sources: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> prensor.NodeTensor: + return self._value + + def calculation_is_identity(self) -> bool: + return False + + def calculation_equal(self, expr: expression.Expression) -> bool: + return isinstance(expr, _DirectExpression) and expr._value is self._value # pylint: disable=protected-access + + def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]: + return self._children.get(field_name) + + def known_field_names(self) -> FrozenSet[path.Step]: + return frozenset(self._children.keys()) + + def __str__(self) -> str: + return f"_DirectExpression: {str(id(self))}" - This expression has no sources, and a NodeTensor is set at construction time. - """ - def __init__( - self, - is_repeated: bool, - my_type: Optional[tf.DType], - value: prensor.NodeTensor, - children: Mapping[path.Step, expression.Expression], - validate_step_format: bool, - ): - """Initializes an expression. 
+def create_expression_from_prensor( + t: prensor.Prensor, validate_step_format: bool = True +) -> expression.Expression: + """Gets an expression representing the prensor. Args: - is_repeated: if the expression is repeated. - my_type: the DType of a field, or None for an internal node. - value: the return value of calculate(...) - children: the subexpressions. + t: The prensor to represent. validate_step_format: If True, validates that steps do not have any characters that could be ambiguously understood as structure delimiters (e.g. "."). If False, such characters are allowed and the client is responsible to ensure to not rely on any auto-coercion of strings to paths. - """ - super().__init__( - is_repeated, my_type, validate_step_format=validate_step_format - ) - self._value = value - self._children = children - - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [] - - def calculate( - self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: - return self._value - - def calculation_is_identity(self) -> bool: - return False - def calculation_equal(self, expr: expression.Expression) -> bool: - return isinstance(expr, _DirectExpression) and expr._value is self._value # pylint: disable=protected-access - - def _get_child_impl(self, - field_name: path.Step) -> Optional[expression.Expression]: - return self._children.get(field_name) - - def known_field_names(self) -> FrozenSet[path.Step]: - return frozenset(self._children.keys()) - - def __str__(self) -> str: - return f"_DirectExpression: {str(id(self))}" - - -def create_expression_from_prensor( - t: prensor.Prensor, validate_step_format: bool = True -) -> expression.Expression: - """Gets an expression representing the prensor. - - Args: - t: The prensor to represent. - validate_step_format: If True, validates that steps do not have any - characters that could be ambiguously understood as structure delimiters - (e.g. "."). If False, such characters are allowed and the client is - responsible to ensure to not rely on any auto-coercion of strings to - paths. - - Returns: - An expression representing the prensor. - """ - node_tensor = t.node - children = { - k: create_expression_from_prensor( - v, validate_step_format=validate_step_format - ) - for k, v in t.get_children().items() - } - if isinstance(node_tensor, prensor.RootNodeTensor): - return _DirectExpression( - True, None, node_tensor, children, validate_step_format - ) - if isinstance(node_tensor, prensor.ChildNodeTensor): + Returns: + An expression representing the prensor. 
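+
+    Example (an illustrative sketch; create_simple_prensor is a helper from
+    this repo's prensor_test_util):
+        t = prensor_test_util.create_simple_prensor()
+        expr = create_expression_from_prensor(t)
+        foo = expr.get_child_or_error("foo")  # an optional int32 leaf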
+    """
+    node_tensor = t.node
+    children = {
+        k: create_expression_from_prensor(v, validate_step_format=validate_step_format)
+        for k, v in t.get_children().items()
+    }
+    if isinstance(node_tensor, prensor.RootNodeTensor):
+        return _DirectExpression(
+            True, None, node_tensor, children, validate_step_format
+        )
+    if isinstance(node_tensor, prensor.ChildNodeTensor):
+        return _DirectExpression(
+            node_tensor.is_repeated,
+            None,
+            node_tensor,
+            children,
+            validate_step_format,
+        )
+    # isinstance(node_tensor, LeafNodeTensor)
     return _DirectExpression(
         node_tensor.is_repeated,
-        None,
+        node_tensor.values.dtype,
         node_tensor,
         children,
         validate_step_format,
     )
-  # isinstance(node_tensor, LeafNodeTensor)
-  return _DirectExpression(
-      node_tensor.is_repeated,
-      node_tensor.values.dtype,
-      node_tensor,
-      children,
-      validate_step_format,
-  )
diff --git a/struct2tensor/create_expression_test.py b/struct2tensor/create_expression_test.py
index 9d08010..655edca 100644
--- a/struct2tensor/create_expression_test.py
+++ b/struct2tensor/create_expression_test.py
@@ -21,136 +21,127 @@


 class ExpressionTest(absltest.TestCase):
+    def test_root_methods(self):
+        """Test all the basic operations on the root of a prensor expression."""
+        expression = prensor_test_util.create_simple_prensor()
+        expr = create_expression.create_expression_from_prensor(expression)
+        self.assertTrue(expr.is_repeated)
+        self.assertIsNone(expr.type)
+        self.assertFalse(expr.is_leaf)
+        self.assertEqual(expr.get_source_expressions(), [])
+        self.assertFalse(expr.calculation_is_identity())
+        self.assertTrue(expr.calculation_equal(expr))
+        root_node = expression_test_util.calculate_value_slowly(expr)
+        self.assertEqual(root_node.size.dtype, tf.int64)
+        self.assertEqual(expr.known_field_names(), frozenset(["foo", "foorepeated"]))

-  def test_root_methods(self):
-    """Test all the basic operations on the root of a prensor expression."""
-    expression = prensor_test_util.create_simple_prensor()
-    expr = create_expression.create_expression_from_prensor(expression)
-    self.assertTrue(expr.is_repeated)
-    self.assertIsNone(expr.type)
-    self.assertFalse(expr.is_leaf)
-    self.assertEqual(expr.get_source_expressions(), [])
-    self.assertFalse(expr.calculation_is_identity())
-    self.assertTrue(expr.calculation_equal(expr))
-    root_node = expression_test_util.calculate_value_slowly(expr)
-    self.assertEqual(root_node.size.dtype, tf.int64)
-    self.assertEqual(expr.known_field_names(), frozenset(["foo",
-                                                          "foorepeated"]))
+    def test_foo_node(self):
+        """Test all the basic operations on an optional leaf of an expression."""
+        simple_prensor = prensor_test_util.create_simple_prensor()
+        expr = create_expression.create_expression_from_prensor(simple_prensor)
+        foo_expr = expr.get_child_or_error("foo")
+        self.assertFalse(foo_expr.is_repeated)
+        self.assertEqual(foo_expr.type, tf.int32)
+        self.assertTrue(foo_expr.is_leaf)
+        self.assertEqual(foo_expr.get_source_expressions(), [])
+        self.assertFalse(foo_expr.calculation_is_identity())
+        self.assertTrue(foo_expr.calculation_equal(foo_expr))
+        self.assertFalse(foo_expr.calculation_equal(expr))
+        leaf_node = expression_test_util.calculate_value_slowly(foo_expr)
+        self.assertEqual(leaf_node.values.dtype, tf.int32)
+        self.assertEqual(foo_expr.known_field_names(), frozenset())

-  def test_foo_node(self):
-    """Test all the basic operations an optional leaf of an expression."""
-    simple_prensor = prensor_test_util.create_simple_prensor()
-    expr = create_expression.create_expression_from_prensor(simple_prensor)
-    foo_expr = expr.get_child_or_error("foo")
-    self.assertFalse(foo_expr.is_repeated)
-    self.assertEqual(foo_expr.type, tf.int32)
-    self.assertTrue(foo_expr.is_leaf)
-    self.assertEqual(foo_expr.get_source_expressions(), [])
-    self.assertFalse(foo_expr.calculation_is_identity())
-    self.assertTrue(foo_expr.calculation_equal(foo_expr))
-    self.assertFalse(foo_expr.calculation_equal(expr))
-    leaf_node = expression_test_util.calculate_value_slowly(foo_expr)
-    self.assertEqual(leaf_node.values.dtype, tf.int32)
-    self.assertEqual(foo_expr.known_field_names(), frozenset())
+    def test_foorepeated_node(self):
+        """Test all the basic operations on a repeated leaf of an expression."""
+        expression = prensor_test_util.create_simple_prensor()
+        expr = create_expression.create_expression_from_prensor(expression)
+        foorepeated_expr = expr.get_child_or_error("foorepeated")
+        self.assertTrue(foorepeated_expr.is_repeated)
+        self.assertEqual(foorepeated_expr.type, tf.int32)
+        self.assertTrue(foorepeated_expr.is_leaf)
+        self.assertEqual(foorepeated_expr.get_source_expressions(), [])
+        self.assertFalse(foorepeated_expr.calculation_is_identity())
+        self.assertTrue(foorepeated_expr.calculation_equal(foorepeated_expr))
+        self.assertFalse(foorepeated_expr.calculation_equal(expr))
+        leaf_node = expression_test_util.calculate_value_slowly(foorepeated_expr)
+        self.assertEqual(leaf_node.values.dtype, tf.int32)
+        self.assertEqual(foorepeated_expr.known_field_names(), frozenset())

-  def test_foorepeated_node(self):
-    """Test all the basic operations an repeated leaf of an expression."""
-    expression = prensor_test_util.create_simple_prensor()
-    expr = create_expression.create_expression_from_prensor(expression)
-    foorepeated_expr = expr.get_child_or_error("foorepeated")
-    self.assertTrue(foorepeated_expr.is_repeated)
-    self.assertEqual(foorepeated_expr.type, tf.int32)
-    self.assertTrue(foorepeated_expr.is_leaf)
-    self.assertEqual(foorepeated_expr.get_source_expressions(), [])
-    self.assertFalse(foorepeated_expr.calculation_is_identity())
-    self.assertTrue(foorepeated_expr.calculation_equal(foorepeated_expr))
-    self.assertFalse(foorepeated_expr.calculation_equal(expr))
-    leaf_node = expression_test_util.calculate_value_slowly(foorepeated_expr)
-    self.assertEqual(leaf_node.values.dtype, tf.int32)
-    self.assertEqual(foorepeated_expr.known_field_names(), frozenset())
+    def test_get_descendants(self):
+        expression = prensor_test_util.create_nested_prensor()
+        expr = create_expression.create_expression_from_prensor(expression)
+        expr_user_friends = expr.get_descendant_or_error(path.Path(["user", "friends"]))
+        self.assertIs(
+            expr_user_friends,
+            expr.get_child_or_error("user").get_child_or_error("friends"),
+        )

-  def test_get_descendants(self):
-    expression = prensor_test_util.create_nested_prensor()
-    expr = create_expression.create_expression_from_prensor(expression)
-    expr_user_friends = expr.get_descendant_or_error(
-        path.Path(["user", "friends"]))
-    self.assertIs(expr_user_friends,
-                  expr.get_child_or_error("user").get_child_or_error("friends"))
+    def test_get_known_descendants(self):
+        expression = prensor_test_util.create_nested_prensor()
+        expr = create_expression.create_expression_from_prensor(expression)
+        expr_map = expr.get_known_descendants()
+        self.assertIn(path.Path(["doc"]), expr_map)
+        self.assertIn(path.Path(["doc", "bar"]), expr_map)
+        self.assertIn(path.Path(["doc", "keep_me"]), expr_map)
+        self.assertIn(path.Path(["user"]), expr_map)
+        self.assertIn(path.Path(["user", "friends"]), expr_map)

-  def 
test_get_known_descendants(self): - expression = prensor_test_util.create_nested_prensor() - expr = create_expression.create_expression_from_prensor(expression) - expr_map = expr.get_known_descendants() - self.assertIn(path.Path(["doc"]), expr_map) - self.assertIn(path.Path(["doc", "bar"]), expr_map) - self.assertIn(path.Path(["doc", "keep_me"]), expr_map) - self.assertIn(path.Path(["user"]), expr_map) - self.assertIn(path.Path(["user", "friends"]), expr_map) + def test_lenient_formatting_get_known_descendants(self): + expression = prensor_test_util.create_nested_prensor_with_lenient_field_names() + expr = create_expression.create_expression_from_prensor( + expression, validate_step_format=False + ) + expr_map = expr.get_known_descendants() + self.assertIn(path.Path(["doc"], validate_step_format=False), expr_map) + self.assertIn( + path.Path(["doc", "bar.baz"], validate_step_format=False), expr_map + ) + self.assertIn( + path.Path(["doc", "keep_me/x"], validate_step_format=False), expr_map + ) + self.assertIn(path.Path(["user"], validate_step_format=False), expr_map) + self.assertIn( + path.Path(["user", "friends!:)"], validate_step_format=False), expr_map + ) - def test_lenient_formatting_get_known_descendants(self): - expression = ( - prensor_test_util.create_nested_prensor_with_lenient_field_names() - ) - expr = create_expression.create_expression_from_prensor( - expression, validate_step_format=False - ) - expr_map = expr.get_known_descendants() - self.assertIn(path.Path(["doc"], validate_step_format=False), expr_map) - self.assertIn( - path.Path(["doc", "bar.baz"], validate_step_format=False), expr_map - ) - self.assertIn( - path.Path(["doc", "keep_me/x"], validate_step_format=False), expr_map - ) - self.assertIn(path.Path(["user"], validate_step_format=False), expr_map) - self.assertIn( - path.Path(["user", "friends!:)"], validate_step_format=False), expr_map - ) + def test_lenient_formatting_map_sparse_tensors(self): + expression = prensor_test_util.create_nested_prensor_with_lenient_field_names() + expr = create_expression.create_expression_from_prensor( + expression, validate_step_format=False + ) + expr_map = expr.map_sparse_tensors( + "doc", + ["bar.baz", "keep_me/x"], + tf.compat.v1.sparse_add, + False, + tf.int64, + "total", + ) + self.assertIsNotNone(expr_map) - def test_lenient_formatting_map_sparse_tensors(self): - expression = ( - prensor_test_util.create_nested_prensor_with_lenient_field_names() - ) - expr = create_expression.create_expression_from_prensor( - expression, validate_step_format=False - ) - expr_map = expr.map_sparse_tensors( - "doc", - ["bar.baz", "keep_me/x"], - tf.compat.v1.sparse_add, - False, - tf.int64, - "total", - ) - self.assertIsNotNone(expr_map) + def test_lenient_formatting_map_ragged_tensors(self): + expression = prensor_test_util.create_nested_prensor_with_lenient_field_names() + expr = create_expression.create_expression_from_prensor( + expression, validate_step_format=False + ) + expr_map = expr.map_ragged_tensors( + "doc", + ["bar.baz", "keep_me/x"], + lambda x: x, + False, + tf.int64, + "total", + ) + self.assertIsNotNone(expr_map) - def test_lenient_formatting_map_ragged_tensors(self): - expression = ( - prensor_test_util.create_nested_prensor_with_lenient_field_names() - ) - expr = create_expression.create_expression_from_prensor( - expression, validate_step_format=False - ) - expr_map = expr.map_ragged_tensors( - "doc", - ["bar.baz", "keep_me/x"], - lambda x: x, - False, - tf.int64, - "total", - ) - self.assertIsNotNone(expr_map) - - def 
test_lenient_formatting_get_paths_with_schema(self): - expression = ( - prensor_test_util.create_nested_prensor_with_lenient_field_names() - ) - expr = create_expression.create_expression_from_prensor( - expression, validate_step_format=False - ) - paths = expr.get_paths_with_schema() - self.assertLen(paths, 1) + def test_lenient_formatting_get_paths_with_schema(self): + expression = prensor_test_util.create_nested_prensor_with_lenient_field_names() + expr = create_expression.create_expression_from_prensor( + expression, validate_step_format=False + ) + paths = expr.get_paths_with_schema() + self.assertLen(paths, 1) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression.py b/struct2tensor/expression.py index 2d90790..9623bb0 100644 --- a/struct2tensor/expression.py +++ b/struct2tensor/expression.py @@ -34,7 +34,7 @@ import tensorflow as tf from tensorflow.python.util.lazy_loader import ( - LazyLoader, # pylint: disable=g-direct-tensorflow-import + LazyLoader, # pylint: disable=g-direct-tensorflow-import ) from tensorflow_metadata.proto.v0 import schema_pb2 @@ -48,34 +48,39 @@ # upon this one. # Similar to: # from struct2tensor.expression_impl import promote -promote = LazyLoader("promote", globals(), - "struct2tensor.expression_impl.promote") +promote = LazyLoader("promote", globals(), "struct2tensor.expression_impl.promote") -broadcast = LazyLoader("broadcast", globals(), - "struct2tensor.expression_impl.broadcast") +broadcast = LazyLoader( + "broadcast", globals(), "struct2tensor.expression_impl.broadcast" +) promote_and_broadcast = LazyLoader( - "promote_and_broadcast", globals(), "struct2tensor.expression_impl" - ".promote_and_broadcast") + "promote_and_broadcast", + globals(), + "struct2tensor.expression_impl.promote_and_broadcast", +) -map_values = LazyLoader("map_values", globals(), - "struct2tensor.expression_impl.map_values") +map_values = LazyLoader( + "map_values", globals(), "struct2tensor.expression_impl.map_values" +) -project = LazyLoader("project", globals(), - "struct2tensor.expression_impl.project") +project = LazyLoader("project", globals(), "struct2tensor.expression_impl.project") size = LazyLoader("size", globals(), "struct2tensor.expression_impl.size") reroot = LazyLoader("reroot", globals(), "struct2tensor.expression_impl.reroot") -map_prensor = LazyLoader("map_prensor", globals(), - "struct2tensor.expression_impl.map_prensor") +map_prensor = LazyLoader( + "map_prensor", globals(), "struct2tensor.expression_impl.map_prensor" +) -apply_schema = LazyLoader("apply_schema", globals(), - "struct2tensor.expression_impl.apply_schema") +apply_schema = LazyLoader( + "apply_schema", globals(), "struct2tensor.expression_impl.apply_schema" +) -slice_expression = LazyLoader("slice_expression", globals(), - "struct2tensor.expression_impl.slice_expression") +slice_expression = LazyLoader( + "slice_expression", globals(), "struct2tensor.expression_impl.slice_expression" +) # Type for limit arguments to slice (begin, end). # Union[int, tf.Tensor, tf.Variable] @@ -83,586 +88,620 @@ class Expression(metaclass=abc.ABCMeta): - """An expression represents the calculation of a prensor object.""" - - def __init__( - self, - is_repeated: bool, - my_type: Optional[tf.DType], - schema_feature: Optional[schema_pb2.Feature] = None, - validate_step_format: bool = True, - ): - """Initialize an expression. - - Args: - is_repeated: if the expression is repeated. - my_type: the DType of a field, or None for an internal node. 
- schema_feature: the local schema (StructDomain information should not be - present). - validate_step_format: If True, validates that steps do not have any - characters that could be ambiguously understood as structure delimiters - (e.g. "."). If False, such characters are allowed and the client is - responsible to ensure to not rely on any auto-coercion of strings to - paths. - """ - self._is_repeated = is_repeated - self._type = my_type - self._child_cache = {} - self._schema_feature = schema_feature - self._validate_step_format = validate_step_format - - @property - def is_repeated(self) -> bool: - """True iff the same parent value can have multiple children values.""" - return self._is_repeated - - @property - def type(self) -> Optional[tf.DType]: - """Dtype of the expression, or None if not a leaf expression.""" - return self._type - - @property - def is_leaf(self) -> bool: - """True iff the node tensor is a LeafNodeTensor.""" - return self.type is not None - - @property - def schema_feature(self) -> Optional[schema_pb2.Feature]: - """Return the schema of the field.""" - return self._schema_feature - - @property - def validate_step_format(self) -> bool: - return self._validate_step_format - - @abc.abstractmethod - def get_source_expressions(self) -> Sequence["Expression"]: - """Gets the sources of this expression. - - The node tensors of the source expressions must be sufficient to - calculate the node tensor of this expression - (see calculate and calculate_value_slowly). - - Returns: - The sources of this expression. - """ - raise NotImplementedError() - - @abc.abstractmethod - def calculate( - self, - source_tensors: Sequence[prensor.NodeTensor], - destinations: Sequence["Expression"], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: - """Calculates the node tensor of the expression. - - The node tensor must be a function of the properties of the expression - and the node tensors of the expressions from get_source_expressions(). - - If is_leaf, then calculate must return a LeafNodeTensor. - Otherwise, it must return a ChildNodeTensor or RootNodeTensor. - - If calculate_is_identity is true, then this must return source_tensors[0]. - - Sometimes, for operations such as parsing the proto, calculate will return - additional information. For example, calculate() for the root of the - proto expression also parses out the tensors required to calculate the - tensors of the children. This is why destinations are required. - - For a reference use, see calculate_value_slowly(...) below. - - Args: - source_tensors: The node tensors of the expressions in - get_source_expressions(). - destinations: The expressions that will use the output of this method. - options: Options for the calculation. - side_info: An optional prensor that is used to bind to a placeholder - expression. - - Returns: - A NodeTensor representing the output of this expression. - """ - raise NotImplementedError() - - @abc.abstractmethod - def calculation_is_identity(self) -> bool: - """True iff the self.calculate is the identity. - - There is exactly one source, and the output of self.calculate(...) is the - node tensor of this source. - """ - raise NotImplementedError() - - @abc.abstractmethod - def calculation_equal(self, expression: "Expression") -> bool: - """self.calculate is equal to another expression.calculate. - - Given the same source node tensors, self.calculate(...) and - expression.calculate(...) will have the same result. 
- - Note that this does not check that the source expressions of the two - expressions are the same. Therefore, two operations can have the same - calculation, but not the same output, because their sources are different. - For example, if a.calculation_is_identity() is True and - b.calculation_is_identity() is True, then a.calculation_equal(b) is True. - However, unless a and b have the same source, the expressions themselves are - not equal. - - Args: - expression: The expression to compare to. - """ - raise NotImplementedError() - - @abc.abstractmethod - def _get_child_impl(self, field_name: path.Step) -> Optional["Expression"]: - """Implementation of getting a named child in a subclass. - - This is called and cached inside get_child(). - - Args: - field_name: the field accessed. - - Returns: - The expression of the field, or None if the field does not exist. - """ - raise NotImplementedError() - - @abc.abstractmethod - def known_field_names(self) -> FrozenSet[path.Step]: - """Returns known field names of the expression. - - TODO(martinz): implement set_field and project. - Known field names of a parsed proto correspond to the fields declared in - the message. Examples of "unknown" fields are extensions and explicit casts - in an any field. The only way to know if an unknown field "(foo.bar)" is - present in an expression expr is to call (expr["(foo.bar)"] is not None). - - Notice that simply accessing a field does not make it "known". However, - setting a field (or setting a descendant of a field) will make it known. - - project(...) returns an expression where the known field names are the only - field names. In general, if you want to depend upon known_field_names - (e.g., if you want to compile a expression), then the best approach is to - project() the expression first. - - Returns: - An immutable set of field names. - """ - - def get_child(self, field_name: path.Step) -> Optional["Expression"]: - """Gets a named child.""" - if field_name in self._child_cache: - return self._child_cache[field_name] - result = self._get_child_impl(field_name) - self._child_cache[field_name] = result - return result - - def get_child_or_error(self, field_name: path.Step) -> "Expression": - """Gets a named child.""" - result = self.get_child(field_name) - if result is None: - raise KeyError(f"No such field: {field_name}") - return result - - def get_descendant(self, p: path.Path) -> Optional["Expression"]: - """Finds the descendant at the path.""" - result = self - for field_name in p.field_list: - result = result.get_child(field_name) - if result is None: - return None - return result - - def get_descendant_or_error(self, p: path.Path) -> "Expression": - """Finds the descendant at the path.""" - result = self.get_descendant(p) - if result is None: - raise ValueError(f"Missing path: {str(p)} in {self.schema_string(limit=20)}") - return result - - def get_known_children(self) -> Mapping[path.Step, "Expression"]: - known_field_names = self.known_field_names() - result = {} - for name in known_field_names: - result[name] = self.get_child_or_error(name) - return result - - def get_known_descendants(self) -> Mapping[path.Path, "Expression"]: - # Rename get_known_descendants - """Gets a mapping from known paths to subexpressions. - - The difference between this and get_descendants in Prensor is that - all paths in a Prensor are realized, thus all known. 
But an Expression's - descendants might not all be known at the point this method is called, - because an expression may have an infinite number of children. - - Returns: - A mapping from paths (relative to the root of the subexpression) to - expressions. - """ - known_subexpressions = { - k: v.get_known_descendants() - for k, v in self.get_known_children().items() - } - result = {} - for field_name, subexpression in known_subexpressions.items(): - subexpression_path = path.Path( - [field_name], validate_step_format=self.validate_step_format - ) - for p, expr in subexpression.items(): - result[subexpression_path.concat(p)] = expr - result[path.Path([], validate_step_format=self.validate_step_format)] = self - return result - - def _schema_string_helper(self, field_name: path.Step, - limit: Optional[int]) -> List[str]: - """Helper for schema_string.""" - repeated_as_string = "repeated" if self.is_repeated else "optional" - if self.type is None: - result = [f"{repeated_as_string} {str(field_name)}:"] - else: - result = [ - f"{repeated_as_string} {str(self.type)} {str(field_name)}" - ] - if limit is not None and limit == 0: - if self.get_known_children(): - result.append(" ...") - return result - - new_limit = None if limit is None else limit - 1 - - for field_name, subexpression in self.get_known_children().items(): - recursive = subexpression._schema_string_helper(field_name, new_limit) # pylint: disable=protected-access - result.extend([f" {x}" for x in recursive]) - return result - - # Begin methods compatible with v1 API. ###################################### - # TODO(martinz): Implement cogroup_by_index. - def map_sparse_tensors(self, parent_path: CoercableToPath, - source_fields: Sequence[path.Step], - operator: Callable[..., tf.SparseTensor], - is_repeated: bool, dtype: tf.DType, - new_field_name: path.Step) -> "Expression": - """Maps a set of primitive fields of a message to a new field. - - Unlike map_field_values, this operation allows you to some degree reshape - the field. For instance, you can take two optional fields and create a - repeated field, or perform a reduce_sum on the last dimension of a repeated - field and create an optional field. The key constraint is that the operator - must return a sparse tensor of the correct dimension: i.e., a - 2D sparse tensor if is_repeated is true, or a 1D sparse tensor if - is_repeated is false. Moreover, the first dimension of the sparse tensor - must be equal to the first dimension of the input tensor. - - Args: - parent_path: the parent of the input and output fields. - source_fields: the nonempty list of names of the source fields. - operator: an operator that takes len(source_fields) sparse tensors and - returns a sparse tensor of the appropriate shape. - is_repeated: whether the output is repeated. - dtype: the dtype of the result. - new_field_name: the name of the resulting field. - - Returns: - A new query. - """ - return map_prensor.map_sparse_tensor( + """An expression represents the calculation of a prensor object.""" + + def __init__( + self, + is_repeated: bool, + my_type: Optional[tf.DType], + schema_feature: Optional[schema_pb2.Feature] = None, + validate_step_format: bool = True, + ): + """Initialize an expression. + + Args: + is_repeated: if the expression is repeated. + my_type: the DType of a field, or None for an internal node. + schema_feature: the local schema (StructDomain information should not be + present). 
+ validate_step_format: If True, validates that steps do not have any + characters that could be ambiguously understood as structure delimiters + (e.g. "."). If False, such characters are allowed and the client is + responsible to ensure to not rely on any auto-coercion of strings to + paths. + """ + self._is_repeated = is_repeated + self._type = my_type + self._child_cache = {} + self._schema_feature = schema_feature + self._validate_step_format = validate_step_format + + @property + def is_repeated(self) -> bool: + """True iff the same parent value can have multiple children values.""" + return self._is_repeated + + @property + def type(self) -> Optional[tf.DType]: + """Dtype of the expression, or None if not a leaf expression.""" + return self._type + + @property + def is_leaf(self) -> bool: + """True iff the node tensor is a LeafNodeTensor.""" + return self.type is not None + + @property + def schema_feature(self) -> Optional[schema_pb2.Feature]: + """Return the schema of the field.""" + return self._schema_feature + + @property + def validate_step_format(self) -> bool: + return self._validate_step_format + + @abc.abstractmethod + def get_source_expressions(self) -> Sequence["Expression"]: + """Gets the sources of this expression. + + The node tensors of the source expressions must be sufficient to + calculate the node tensor of this expression + (see calculate and calculate_value_slowly). + + Returns: + The sources of this expression. + """ + raise NotImplementedError() + + @abc.abstractmethod + def calculate( + self, + source_tensors: Sequence[prensor.NodeTensor], + destinations: Sequence["Expression"], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> prensor.NodeTensor: + """Calculates the node tensor of the expression. + + The node tensor must be a function of the properties of the expression + and the node tensors of the expressions from get_source_expressions(). + + If is_leaf, then calculate must return a LeafNodeTensor. + Otherwise, it must return a ChildNodeTensor or RootNodeTensor. + + If calculate_is_identity is true, then this must return source_tensors[0]. + + Sometimes, for operations such as parsing the proto, calculate will return + additional information. For example, calculate() for the root of the + proto expression also parses out the tensors required to calculate the + tensors of the children. This is why destinations are required. + + For a reference use, see calculate_value_slowly(...) below. + + Args: + source_tensors: The node tensors of the expressions in + get_source_expressions(). + destinations: The expressions that will use the output of this method. + options: Options for the calculation. + side_info: An optional prensor that is used to bind to a placeholder + expression. + + Returns: + A NodeTensor representing the output of this expression. + """ + raise NotImplementedError() + + @abc.abstractmethod + def calculation_is_identity(self) -> bool: + """True iff the self.calculate is the identity. + + There is exactly one source, and the output of self.calculate(...) is the + node tensor of this source. + """ + raise NotImplementedError() + + @abc.abstractmethod + def calculation_equal(self, expression: "Expression") -> bool: + """self.calculate is equal to another expression.calculate. + + Given the same source node tensors, self.calculate(...) and + expression.calculate(...) will have the same result. + + Note that this does not check that the source expressions of the two + expressions are the same. 
Therefore, two operations can have the same + calculation, but not the same output, because their sources are different. + For example, if a.calculation_is_identity() is True and + b.calculation_is_identity() is True, then a.calculation_equal(b) is True. + However, unless a and b have the same source, the expressions themselves are + not equal. + + Args: + expression: The expression to compare to. + """ + raise NotImplementedError() + + @abc.abstractmethod + def _get_child_impl(self, field_name: path.Step) -> Optional["Expression"]: + """Implementation of getting a named child in a subclass. + + This is called and cached inside get_child(). + + Args: + field_name: the field accessed. + + Returns: + The expression of the field, or None if the field does not exist. + """ + raise NotImplementedError() + + @abc.abstractmethod + def known_field_names(self) -> FrozenSet[path.Step]: + """Returns known field names of the expression. + + TODO(martinz): implement set_field and project. + Known field names of a parsed proto correspond to the fields declared in + the message. Examples of "unknown" fields are extensions and explicit casts + in an any field. The only way to know if an unknown field "(foo.bar)" is + present in an expression expr is to call (expr["(foo.bar)"] is not None). + + Notice that simply accessing a field does not make it "known". However, + setting a field (or setting a descendant of a field) will make it known. + + project(...) returns an expression where the known field names are the only + field names. In general, if you want to depend upon known_field_names + (e.g., if you want to compile a expression), then the best approach is to + project() the expression first. + + Returns: + An immutable set of field names. + """ + + def get_child(self, field_name: path.Step) -> Optional["Expression"]: + """Gets a named child.""" + if field_name in self._child_cache: + return self._child_cache[field_name] + result = self._get_child_impl(field_name) + self._child_cache[field_name] = result + return result + + def get_child_or_error(self, field_name: path.Step) -> "Expression": + """Gets a named child.""" + result = self.get_child(field_name) + if result is None: + raise KeyError(f"No such field: {field_name}") + return result + + def get_descendant(self, p: path.Path) -> Optional["Expression"]: + """Finds the descendant at the path.""" + result = self + for field_name in p.field_list: + result = result.get_child(field_name) + if result is None: + return None + return result + + def get_descendant_or_error(self, p: path.Path) -> "Expression": + """Finds the descendant at the path.""" + result = self.get_descendant(p) + if result is None: + raise ValueError( + f"Missing path: {str(p)} in {self.schema_string(limit=20)}" + ) + return result + + def get_known_children(self) -> Mapping[path.Step, "Expression"]: + known_field_names = self.known_field_names() + result = {} + for name in known_field_names: + result[name] = self.get_child_or_error(name) + return result + + def get_known_descendants(self) -> Mapping[path.Path, "Expression"]: + # Rename get_known_descendants + """Gets a mapping from known paths to subexpressions. + + The difference between this and get_descendants in Prensor is that + all paths in a Prensor are realized, thus all known. But an Expression's + descendants might not all be known at the point this method is called, + because an expression may have an infinite number of children. 
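# Illustrative sketch of the traversal methods above, assuming an expression
# built from the nested test prensor used in the tests later in this patch
# (the import paths are assumptions; they are not shown in this diff):
from struct2tensor import create_expression, path
from struct2tensor.test import prensor_test_util

expr = create_expression.create_expression_from_prensor(
    prensor_test_util.create_nested_prensor())
friends = expr.get_descendant_or_error(path.Path(["user", "friends"]))
known = expr.get_known_descendants()
# `known` maps path.Path([]) to `expr` itself and
# path.Path(["user", "friends"]) to `friends`; only known fields appear.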
+ + Returns: + A mapping from paths (relative to the root of the subexpression) to + expressions. + """ + known_subexpressions = { + k: v.get_known_descendants() for k, v in self.get_known_children().items() + } + result = {} + for field_name, subexpression in known_subexpressions.items(): + subexpression_path = path.Path( + [field_name], validate_step_format=self.validate_step_format + ) + for p, expr in subexpression.items(): + result[subexpression_path.concat(p)] = expr + result[path.Path([], validate_step_format=self.validate_step_format)] = self + return result + + def _schema_string_helper( + self, field_name: path.Step, limit: Optional[int] + ) -> List[str]: + """Helper for schema_string.""" + repeated_as_string = "repeated" if self.is_repeated else "optional" + if self.type is None: + result = [f"{repeated_as_string} {str(field_name)}:"] + else: + result = [f"{repeated_as_string} {str(self.type)} {str(field_name)}"] + if limit is not None and limit == 0: + if self.get_known_children(): + result.append(" ...") + return result + + new_limit = None if limit is None else limit - 1 + + for field_name, subexpression in self.get_known_children().items(): + recursive = subexpression._schema_string_helper(field_name, new_limit) # pylint: disable=protected-access + result.extend([f" {x}" for x in recursive]) + return result + + # Begin methods compatible with v1 API. ###################################### + # TODO(martinz): Implement cogroup_by_index. + def map_sparse_tensors( + self, + parent_path: CoercableToPath, + source_fields: Sequence[path.Step], + operator: Callable[..., tf.SparseTensor], + is_repeated: bool, + dtype: tf.DType, + new_field_name: path.Step, + ) -> "Expression": + """Maps a set of primitive fields of a message to a new field. + + Unlike map_field_values, this operation allows you to some degree reshape + the field. For instance, you can take two optional fields and create a + repeated field, or perform a reduce_sum on the last dimension of a repeated + field and create an optional field. The key constraint is that the operator + must return a sparse tensor of the correct dimension: i.e., a + 2D sparse tensor if is_repeated is true, or a 1D sparse tensor if + is_repeated is false. Moreover, the first dimension of the sparse tensor + must be equal to the first dimension of the input tensor. + + Args: + parent_path: the parent of the input and output fields. + source_fields: the nonempty list of names of the source fields. + operator: an operator that takes len(source_fields) sparse tensors and + returns a sparse tensor of the appropriate shape. + is_repeated: whether the output is repeated. + dtype: the dtype of the result. + new_field_name: the name of the resulting field. + + Returns: + A new query. + """ + return map_prensor.map_sparse_tensor( + self, + path.create_path(parent_path), + [ + path.Path([f], validate_step_format=self.validate_step_format) + for f in source_fields + ], + operator, + is_repeated, + dtype, + new_field_name, + ) + + def map_ragged_tensors( self, - path.create_path(parent_path), - [ - path.Path([f], validate_step_format=self.validate_step_format) - for f in source_fields - ], - operator, - is_repeated, - dtype, - new_field_name, - ) - - def map_ragged_tensors(self, parent_path: CoercableToPath, - source_fields: Sequence[path.Step], - operator: Callable[..., tf.SparseTensor], - is_repeated: bool, dtype: tf.DType, - new_field_name: path.Step) -> "Expression": - """Maps a set of primitive fields of a message to a new field. 
- - Unlike map_field_values, this operation allows you to some degree reshape - the field. For instance, you can take two optional fields and create a - repeated field, or perform a reduce_sum on the last dimension of a repeated - field and create an optional field. The key constraint is that the operator - must return a sparse tensor of the correct dimension: i.e., a - 2D sparse tensor if is_repeated is true, or a 1D sparse tensor if - is_repeated is false. Moreover, the first dimension of the sparse tensor - must be equal to the first dimension of the input tensor. - - Args: - parent_path: the parent of the input and output fields. - source_fields: the nonempty list of names of the source fields. - operator: an operator that takes len(source_fields) sparse tensors and - returns a sparse tensor of the appropriate shape. - is_repeated: whether the output is repeated. - dtype: the dtype of the result. - new_field_name: the name of the resulting field. - - Returns: - A new query. - """ - return map_prensor.map_ragged_tensor( + parent_path: CoercableToPath, + source_fields: Sequence[path.Step], + operator: Callable[..., tf.SparseTensor], + is_repeated: bool, + dtype: tf.DType, + new_field_name: path.Step, + ) -> "Expression": + """Maps a set of primitive fields of a message to a new field. + + Unlike map_field_values, this operation allows you to some degree reshape + the field. For instance, you can take two optional fields and create a + repeated field, or perform a reduce_sum on the last dimension of a repeated + field and create an optional field. The key constraint is that the operator + must return a sparse tensor of the correct dimension: i.e., a + 2D sparse tensor if is_repeated is true, or a 1D sparse tensor if + is_repeated is false. Moreover, the first dimension of the sparse tensor + must be equal to the first dimension of the input tensor. + + Args: + parent_path: the parent of the input and output fields. + source_fields: the nonempty list of names of the source fields. + operator: an operator that takes len(source_fields) sparse tensors and + returns a sparse tensor of the appropriate shape. + is_repeated: whether the output is repeated. + dtype: the dtype of the result. + new_field_name: the name of the resulting field. + + Returns: + A new query. + """ + return map_prensor.map_ragged_tensor( + self, + path.create_path(parent_path), + [ + path.Path([f], validate_step_format=self.validate_step_format) + for f in source_fields + ], + operator, + is_repeated, + dtype, + new_field_name, + ) + + def truncate( self, - path.create_path(parent_path), - [ - path.Path([f], validate_step_format=self.validate_step_format) - for f in source_fields - ], - operator, - is_repeated, - dtype, - new_field_name, - ) - - def truncate(self, source_path: CoercableToPath, limit: Union[int, tf.Tensor], - new_field_name: path.Step) -> "Expression": - """Creates a truncated copy of source_path at new_field_path.""" - return self.slice(source_path, new_field_name, end=limit) - - def slice(self, - source_path: CoercableToPath, - new_field_name: path.Step, - begin: Optional[IndexValue] = None, - end: Optional[IndexValue] = None) -> "Expression": - """Creates a slice copy of source_path at new_field_path. - - Note that if begin or end is negative, it is considered relative to - the size of the array. e.g., slice(...,begin=-1) will get the last - element of every array. - - Args: - source_path: the source of the slice. - new_field_name: the new field that is generated. 
- begin: the beginning of the slice (inclusive). - end: the end of the slice (exclusive). - - Returns: - An Expression object representing the result of the operation. - """ - return slice_expression.slice_expression(self, - path.create_path(source_path), - new_field_name, begin, end) - - def promote(self, source_path: CoercableToPath, new_field_name: path.Step): - """Promotes source_path to be a field new_field_name in its grandparent.""" - return promote.promote(self, path.create_path(source_path), new_field_name) - - def broadcast(self, source_path: CoercableToPath, sibling_field: path.Step, - new_field_name: path.Step) -> "Expression": - """Broadcasts the existing field at source_path to the sibling_field.""" - return broadcast.broadcast(self, path.create_path(source_path), - sibling_field, new_field_name) - - def project(self, path_list: Sequence[CoercableToPath]) -> "Expression": - """Constrains the paths to those listed.""" - return project.project(self, [path.create_path(x) for x in path_list]) - - def promote_and_broadcast( - self, path_dictionary: Mapping[path.Step, CoercableToPath], - dest_path_parent: CoercableToPath) -> "Expression": - return promote_and_broadcast.promote_and_broadcast( - self, {k: path.create_path(v) for k, v in path_dictionary.items()}, - path.create_path(dest_path_parent)) - - def map_field_values(self, source_path: CoercableToPath, - operator: Callable[[tf.Tensor], tf.Tensor], - dtype: tf.DType, - new_field_name: path.Step) -> "Expression": - """Map a primitive field to create a new primitive field. - - Note: the dtype argument is added since the v1 API. - - Args: - source_path: the origin path. - operator: an element-wise operator that takes a 1-dimensional vector. - dtype: the type of the output. - new_field_name: the name of a new sibling of source_path. - - Returns: - the resulting root expression. - """ - return map_values.map_values(self, path.create_path(source_path), operator, - dtype, new_field_name) - - def reroot(self, new_root: CoercableToPath) -> "Expression": - """Returns a new list of protocol buffers available at new_root.""" - return reroot.reroot(self, path.create_path(new_root)) - - def create_size_field(self, source_path: CoercableToPath, - new_field_name: path.Step) -> "Expression": - """Creates a field that is the size of the source path.""" - return size.size(self, path.create_path(source_path), new_field_name) - - def create_has_field(self, source_path: CoercableToPath, - new_field_name: path.Step) -> "Expression": - """Creates a field that is the presence of the source path.""" - return size.has(self, path.create_path(source_path), new_field_name) - - def create_proto_index(self, field_name: path.Step) -> "Expression": - """Creates a proto index field as a direct child of the current root. - - The proto index maps each root element to the original batch index. - For example: [0, 2] means the first element came from the first proto - in the original input tensor and the second element came from the third - proto. The created field is always "dense" -- it has the same valency as - the current root. - - Args: - field_name: the name of the field to be created. - - Returns: - An Expression object representing the result of the operation. 
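# Hypothetical usage of the v1-style helpers above; `expr` is a root
# expression with a repeated string leaf "user.friends" (a sketch assuming
# string paths are coerced via path.create_path, as the docstrings describe):
expr2 = expr.slice("user.friends", "first_two", begin=0, end=2)
expr3 = expr2.truncate("user.friends", limit=1, new_field_name="first_one")
expr4 = expr3.promote("user.friends", "friends_at_root")
# Broadcast the promoted field back down into its sibling "user":
expr5 = expr4.broadcast("friends_at_root", "user", "friends_in_user")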
- """ - return reroot.create_proto_index_field(self, field_name) - - def cogroup_by_index(self, source_path: CoercableToPath, left_name: path.Step, - right_name: path.Step, - new_field_name: path.Step) -> "Expression": - """Creates a cogroup of left_name and right_name at new_field_name.""" - raise NotImplementedError("cogroup_by_index is not implemented") - - # End methods compatible with v1 API. ######################################## - - def apply(self, - transform: Callable[["Expression"], "Expression"]) -> "Expression": - return transform(self) - - def apply_schema(self, schema: schema_pb2.Schema) -> "Expression": - return apply_schema.apply_schema(self, schema) - - def get_paths_with_schema(self) -> List[path.Path]: - """Extract only paths that contain schema information.""" - result = [] - for name, child in self.get_known_children().items(): - if child.schema_feature is None: - continue - result.extend( - [ - path.Path( - [name], validate_step_format=self.validate_step_format - ).concat(x) - for x in child.get_paths_with_schema() - ] - ) - # Note: We always take the root path and so will return an empty schema - # if there is no schema information on any nodes, including the root. - if not result: - result.append( - path.Path([], validate_step_format=self.validate_step_format) - ) - return result - - def _populate_schema_feature_children(self, feature_list) -> None: - """Populate a feature list from the children of this node. - - The argument is a protobuf repeated field that is populated. The names - of the features come from the leaf names. - - Args: - feature_list: RepeatedCompositeFieldContainer of schema_pb2.Feature. - """ - for name, child in self.get_known_children().items(): - new_feature = feature_list.add() - if child.schema_feature is None: - if not child.is_repeated: - new_feature.value_count.max = 1 - else: - new_feature.CopyFrom(child.schema_feature) - if child.get_known_children(): - new_feature.type = schema_pb2.FeatureType.STRUCT - child._populate_schema_feature_children( # pylint:disable=protected-access - new_feature.struct_domain.feature) - new_feature.name = name - - def get_schema(self, create_schema_features=True) -> schema_pb2.Schema: - """Returns a schema for the entire tree. - - Args: - create_schema_features: If True, schema features are added for all - children and a schema entry is created if not available on the child. If - False, features are left off of the returned schema if there is no - schema_feature on the child. - """ - if not create_schema_features: - return self.project(self.get_paths_with_schema()).get_schema() - result = schema_pb2.Schema() - self._populate_schema_feature_children(result.feature) - return result - - def schema_string(self, limit: Optional[int] = None) -> str: - """Returns a schema for the expression. - - E.g. - - repeated root: - optional int32 foo - optional bar: - optional string baz - optional int64 bak - - Note that unknown fields and subexpressions are not displayed. - - Args: - limit: if present, limit the recursion. - - Returns: - A string, describing (a part of) the schema. - """ - return "\n".join(self._schema_string_helper("root", limit)) - - def __hash__(self) -> int: - """Returns the id of the expression as the hash. - - Do not override this method. - """ - return id(self) - - def __eq__(self, expr: "Expression") -> bool: - """If hash(expr1) == hash(expr2): then expr1 == expr2. - - Do not override this method. 
- - Args: - expr: The expression to check equality against - - Returns: - Boolean of equality of two expressions - """ - return id(self) == id(expr) - - def __str__(self) -> str: - """If not overridden, returns the schema string.""" - return self.schema_string(limit=20) + source_path: CoercableToPath, + limit: Union[int, tf.Tensor], + new_field_name: path.Step, + ) -> "Expression": + """Creates a truncated copy of source_path at new_field_path.""" + return self.slice(source_path, new_field_name, end=limit) + + def slice( + self, + source_path: CoercableToPath, + new_field_name: path.Step, + begin: Optional[IndexValue] = None, + end: Optional[IndexValue] = None, + ) -> "Expression": + """Creates a slice copy of source_path at new_field_path. + + Note that if begin or end is negative, it is considered relative to + the size of the array. e.g., slice(...,begin=-1) will get the last + element of every array. + + Args: + source_path: the source of the slice. + new_field_name: the new field that is generated. + begin: the beginning of the slice (inclusive). + end: the end of the slice (exclusive). + + Returns: + An Expression object representing the result of the operation. + """ + return slice_expression.slice_expression( + self, path.create_path(source_path), new_field_name, begin, end + ) + + def promote(self, source_path: CoercableToPath, new_field_name: path.Step): + """Promotes source_path to be a field new_field_name in its grandparent.""" + return promote.promote(self, path.create_path(source_path), new_field_name) + + def broadcast( + self, + source_path: CoercableToPath, + sibling_field: path.Step, + new_field_name: path.Step, + ) -> "Expression": + """Broadcasts the existing field at source_path to the sibling_field.""" + return broadcast.broadcast( + self, path.create_path(source_path), sibling_field, new_field_name + ) + + def project(self, path_list: Sequence[CoercableToPath]) -> "Expression": + """Constrains the paths to those listed.""" + return project.project(self, [path.create_path(x) for x in path_list]) + + def promote_and_broadcast( + self, + path_dictionary: Mapping[path.Step, CoercableToPath], + dest_path_parent: CoercableToPath, + ) -> "Expression": + return promote_and_broadcast.promote_and_broadcast( + self, + {k: path.create_path(v) for k, v in path_dictionary.items()}, + path.create_path(dest_path_parent), + ) + + def map_field_values( + self, + source_path: CoercableToPath, + operator: Callable[[tf.Tensor], tf.Tensor], + dtype: tf.DType, + new_field_name: path.Step, + ) -> "Expression": + """Map a primitive field to create a new primitive field. + + Note: the dtype argument is added since the v1 API. + + Args: + source_path: the origin path. + operator: an element-wise operator that takes a 1-dimensional vector. + dtype: the type of the output. + new_field_name: the name of a new sibling of source_path. + + Returns: + the resulting root expression. 
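# Minimal sketch of map_field_values; assumes `expr` has an int64 leaf
# "foo" (field name and dtype follow the big test prensor in this patch):
import tensorflow as tf

doubled = expr.map_field_values(
    "foo", lambda values: values * 2, dtype=tf.int64,
    new_field_name="foo_doubled")
# "foo_doubled" becomes a new sibling of "foo"; the operator runs
# element-wise over the 1-D values tensor.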
+ """ + return map_values.map_values( + self, path.create_path(source_path), operator, dtype, new_field_name + ) + + def reroot(self, new_root: CoercableToPath) -> "Expression": + """Returns a new list of protocol buffers available at new_root.""" + return reroot.reroot(self, path.create_path(new_root)) + + def create_size_field( + self, source_path: CoercableToPath, new_field_name: path.Step + ) -> "Expression": + """Creates a field that is the size of the source path.""" + return size.size(self, path.create_path(source_path), new_field_name) + + def create_has_field( + self, source_path: CoercableToPath, new_field_name: path.Step + ) -> "Expression": + """Creates a field that is the presence of the source path.""" + return size.has(self, path.create_path(source_path), new_field_name) + + def create_proto_index(self, field_name: path.Step) -> "Expression": + """Creates a proto index field as a direct child of the current root. + + The proto index maps each root element to the original batch index. + For example: [0, 2] means the first element came from the first proto + in the original input tensor and the second element came from the third + proto. The created field is always "dense" -- it has the same valency as + the current root. + + Args: + field_name: the name of the field to be created. + + Returns: + An Expression object representing the result of the operation. + """ + return reroot.create_proto_index_field(self, field_name) + + def cogroup_by_index( + self, + source_path: CoercableToPath, + left_name: path.Step, + right_name: path.Step, + new_field_name: path.Step, + ) -> "Expression": + """Creates a cogroup of left_name and right_name at new_field_name.""" + raise NotImplementedError("cogroup_by_index is not implemented") + + # End methods compatible with v1 API. ######################################## + + def apply(self, transform: Callable[["Expression"], "Expression"]) -> "Expression": + return transform(self) + + def apply_schema(self, schema: schema_pb2.Schema) -> "Expression": + return apply_schema.apply_schema(self, schema) + + def get_paths_with_schema(self) -> List[path.Path]: + """Extract only paths that contain schema information.""" + result = [] + for name, child in self.get_known_children().items(): + if child.schema_feature is None: + continue + result.extend( + [ + path.Path( + [name], validate_step_format=self.validate_step_format + ).concat(x) + for x in child.get_paths_with_schema() + ] + ) + # Note: We always take the root path and so will return an empty schema + # if there is no schema information on any nodes, including the root. + if not result: + result.append(path.Path([], validate_step_format=self.validate_step_format)) + return result + + def _populate_schema_feature_children(self, feature_list) -> None: + """Populate a feature list from the children of this node. + + The argument is a protobuf repeated field that is populated. The names + of the features come from the leaf names. + + Args: + feature_list: RepeatedCompositeFieldContainer of schema_pb2.Feature. 
+ """ + for name, child in self.get_known_children().items(): + new_feature = feature_list.add() + if child.schema_feature is None: + if not child.is_repeated: + new_feature.value_count.max = 1 + else: + new_feature.CopyFrom(child.schema_feature) + if child.get_known_children(): + new_feature.type = schema_pb2.FeatureType.STRUCT + child._populate_schema_feature_children( # pylint:disable=protected-access + new_feature.struct_domain.feature + ) + new_feature.name = name + + def get_schema(self, create_schema_features=True) -> schema_pb2.Schema: + """Returns a schema for the entire tree. + + Args: + create_schema_features: If True, schema features are added for all + children and a schema entry is created if not available on the child. If + False, features are left off of the returned schema if there is no + schema_feature on the child. + """ + if not create_schema_features: + return self.project(self.get_paths_with_schema()).get_schema() + result = schema_pb2.Schema() + self._populate_schema_feature_children(result.feature) + return result + + def schema_string(self, limit: Optional[int] = None) -> str: + """Returns a schema for the expression. + + E.g. + + repeated root: + optional int32 foo + optional bar: + optional string baz + optional int64 bak + + Note that unknown fields and subexpressions are not displayed. + + Args: + limit: if present, limit the recursion. + + Returns: + A string, describing (a part of) the schema. + """ + return "\n".join(self._schema_string_helper("root", limit)) + + def __hash__(self) -> int: + """Returns the id of the expression as the hash. + + Do not override this method. + """ + return id(self) + + def __eq__(self, expr: "Expression") -> bool: + """If hash(expr1) == hash(expr2): then expr1 == expr2. + + Do not override this method. + + Args: + expr: The expression to check equality against + + Returns: + Boolean of equality of two expressions + """ + return id(self) == id(expr) + + def __str__(self) -> str: + """If not overridden, returns the schema string.""" + return self.schema_string(limit=20) # TODO(martinz): this will be tested when broadcast is created. class Leaf(Expression): - """An abstract supertype for expression subtypes without any children.""" + """An abstract supertype for expression subtypes without any children.""" - def __init__(self, - is_repeated: bool, - my_type: tf.DType, - schema_feature: Optional[schema_pb2.Feature] = None): # pylint: disable=useless-super-delegation - """Initialize a Leaf. - - Note that a leaf must have a specified type. - - Args: - is_repeated: if the expression is repeated. - my_type: the DType of the field. - schema_feature: schema information about the field. - """ - super().__init__(is_repeated, my_type, schema_feature=schema_feature) - - def _get_child_impl(self, field_name: path.Step) -> Optional["Expression"]: - return None + def __init__( + self, + is_repeated: bool, + my_type: tf.DType, + schema_feature: Optional[schema_pb2.Feature] = None, + ): # pylint: disable=useless-super-delegation + """Initialize a Leaf. + + Note that a leaf must have a specified type. + + Args: + is_repeated: if the expression is repeated. + my_type: the DType of the field. + schema_feature: schema information about the field. 
+ """ + super().__init__(is_repeated, my_type, schema_feature=schema_feature) + + def _get_child_impl(self, field_name: path.Step) -> Optional["Expression"]: + return None - def known_field_names(self) -> FrozenSet[path.Step]: - return frozenset() + def known_field_names(self) -> FrozenSet[path.Step]: + return frozenset() diff --git a/struct2tensor/expression_add.py b/struct2tensor/expression_add.py index 547d64b..e65cd14 100644 --- a/struct2tensor/expression_add.py +++ b/struct2tensor/expression_add.py @@ -28,167 +28,175 @@ def create_subtrees( - path_map: Mapping[path.Path, expression.Expression] -) -> Tuple[Optional[expression.Expression], - Mapping[path.Step, Mapping[path.Path, expression.Expression]]]: - """Breaks a tree into a root node and dictionary of subtrees.""" - subtrees = { - } # type:Mapping[path.Step, Mapping[path.Path, expression.Expression]] - root_expression = None # type:Optional[expression.Expression] - for k, v in path_map.items(): - if not k: - root_expression = v - else: - first_step = k.field_list[0] - suffix = k.suffix(1) - if first_step not in subtrees: - subtrees[first_step] = {} # pytype: disable=unsupported-operands - subtrees[first_step][suffix] = v # pytype: disable=unsupported-operands - return (root_expression, subtrees) + path_map: Mapping[path.Path, expression.Expression], +) -> Tuple[ + Optional[expression.Expression], + Mapping[path.Step, Mapping[path.Path, expression.Expression]], +]: + """Breaks a tree into a root node and dictionary of subtrees.""" + subtrees = {} # type:Mapping[path.Step, Mapping[path.Path, expression.Expression]] + root_expression = None # type:Optional[expression.Expression] + for k, v in path_map.items(): + if not k: + root_expression = v + else: + first_step = k.field_list[0] + suffix = k.suffix(1) + if first_step not in subtrees: + subtrees[first_step] = {} # pytype: disable=unsupported-operands + subtrees[first_step][suffix] = v # pytype: disable=unsupported-operands + return (root_expression, subtrees) class _AddPathsExpression(expression.Expression): - """An expression that is the result of add_paths, adding new fields. - - _AddPathsExpression is an overlay on top of the original expression with - specified paths being added. 
- """ - - def __init__( - self, origin: expression.Expression, - path_map: Mapping[path.Step, Mapping[path.Path, expression.Expression]]): - super().__init__( - origin.is_repeated, - origin.type, - schema_feature=origin.schema_feature, - validate_step_format=origin.validate_step_format, - ) - self._origin = origin - self._path_map = path_map - - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [self._origin] - - def calculate(self, source_values: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: Options) -> prensor.NodeTensor: - if len(source_values) != 1: - raise ValueError("Expected one source.") - return source_values[0] - - def calculation_is_identity(self) -> bool: - return True - - def calculation_equal(self, expr: expression.Expression) -> bool: - return expr.calculation_is_identity() - - def _get_child_impl(self, - field_name: path.Step) -> Optional[expression.Expression]: - child_from_origin = self._origin.get_child(field_name) - path_map = self._path_map.get(field_name) - if path_map is None: - return child_from_origin - set_root_expr, subtrees = create_subtrees(path_map) - if child_from_origin is None: - if set_root_expr is None: - raise ValueError("Must have a value in the original if there are paths") - if subtrees: - return _AddPathsExpression(set_root_expr, subtrees) - return set_root_expr - if set_root_expr is not None: - raise ValueError("Tried to overwrite an existing expression") - return _AddPathsExpression(child_from_origin, subtrees) - - def known_field_names(self) -> FrozenSet[path.Step]: - return self._origin.known_field_names().union(self._path_map.keys()) - - def __str__(self) -> str: - keys_to_add = ",".join([str(k) for k in self._path_map]) - return f"_AddPathsExpression({str(self._origin)}, [{keys_to_add}])" - - -def add_paths(root: expression.Expression, - path_map: Mapping[path.Path, expression.Expression]): - """Creates a new expression based on `root` with paths in `path_map` added. - - This operation should be used with care: e.g., there is no guarantee that - the new expressions make any sense in the place they are put in the tree. It - is useful when wrapping a new Expression type, but should not be used by - end users. - - Prefer add_to to add_paths. - - Args: - root: the root of the tree. - path_map: a map from a path to the new subtree. - - Returns: - a new tree with the nodes from the root and the new subtrees. - """ - for p in path_map: - if root.get_descendant(p.get_parent()) is None: - raise ValueError(f"No parent of {p}") - if root.get_descendant(p) is not None: - raise ValueError(f"Path already set: {str(p)}") - _, map_of_maps = create_subtrees(path_map) - return _AddPathsExpression(root, map_of_maps) - - -def _is_true_source_expression(candidate: expression.Expression, - dest: expression.Expression) -> bool: - """True if dest is an expression derived from candidate through add_paths. - - More precisely, true if dest is the result of zero or more add_paths - operations on candidate, where the (source of)* dest is candidate. - - A "true source" for a node dest has an identical NodeTensor to dest, and the - parent of "true source" is a "true source" for the dest. This means if I have - a child X of dest, I could make a new tree where X is a child of the source. - - See add_to. - - Args: - candidate: the possible source - dest: the possible destination. - - Returns: - True iff source is a true source expression of dest. 
- """ - if dest is candidate: - return True - if isinstance(dest, _AddPathsExpression): - return _is_true_source_expression(candidate, dest._origin) # pylint: disable=protected-access - return False - - -def add_to(root: expression.Expression, - origins: Mapping[path.Path, expression.Expression]): - """Copies subtrees from the origins to the root. - - This operation can be used to reduce the number of expressions in the graph. - 1. The path must be present in the associated origin. - 2. The root must not have any of the paths already. - 3. The root must already have the parent of the path. - 4. The parent of the path in the root must be a source expression of the - parent of the path in the origin. - - Args: - root: the original tree that has new expressions added to it. - origins: mapping from path to trees that have subtrees at the path. - - Returns: - A tree with the root and the additional subtrees. - """ - for p, origin_root in origins.items(): - path_parent = p.get_parent() - if not _is_true_source_expression( - root.get_descendant_or_error(path_parent), - origin_root.get_descendant_or_error(path_parent)): - raise ValueError(f"Not a true source for tree with {str(p)}") - if root.get_descendant(p) is not None: - raise ValueError(f"Already contains {str(p)}.") - path_map = { - p: origin_root.get_descendant_or_error(p) - for p, origin_root in origins.items() - } - return add_paths(root, path_map) + """An expression that is the result of add_paths, adding new fields. + + _AddPathsExpression is an overlay on top of the original expression with + specified paths being added. + """ + + def __init__( + self, + origin: expression.Expression, + path_map: Mapping[path.Step, Mapping[path.Path, expression.Expression]], + ): + super().__init__( + origin.is_repeated, + origin.type, + schema_feature=origin.schema_feature, + validate_step_format=origin.validate_step_format, + ) + self._origin = origin + self._path_map = path_map + + def get_source_expressions(self) -> Sequence[expression.Expression]: + return [self._origin] + + def calculate( + self, + source_values: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: Options, + ) -> prensor.NodeTensor: + if len(source_values) != 1: + raise ValueError("Expected one source.") + return source_values[0] + + def calculation_is_identity(self) -> bool: + return True + + def calculation_equal(self, expr: expression.Expression) -> bool: + return expr.calculation_is_identity() + + def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]: + child_from_origin = self._origin.get_child(field_name) + path_map = self._path_map.get(field_name) + if path_map is None: + return child_from_origin + set_root_expr, subtrees = create_subtrees(path_map) + if child_from_origin is None: + if set_root_expr is None: + raise ValueError("Must have a value in the original if there are paths") + if subtrees: + return _AddPathsExpression(set_root_expr, subtrees) + return set_root_expr + if set_root_expr is not None: + raise ValueError("Tried to overwrite an existing expression") + return _AddPathsExpression(child_from_origin, subtrees) + + def known_field_names(self) -> FrozenSet[path.Step]: + return self._origin.known_field_names().union(self._path_map.keys()) + + def __str__(self) -> str: + keys_to_add = ",".join([str(k) for k in self._path_map]) + return f"_AddPathsExpression({str(self._origin)}, [{keys_to_add}])" + + +def add_paths( + root: expression.Expression, path_map: Mapping[path.Path, expression.Expression] 
+): + """Creates a new expression based on `root` with paths in `path_map` added. + + This operation should be used with care: e.g., there is no guarantee that + the new expressions make any sense in the place they are put in the tree. It + is useful when wrapping a new Expression type, but should not be used by + end users. + + Prefer add_to to add_paths. + + Args: + root: the root of the tree. + path_map: a map from a path to the new subtree. + + Returns: + a new tree with the nodes from the root and the new subtrees. + """ + for p in path_map: + if root.get_descendant(p.get_parent()) is None: + raise ValueError(f"No parent of {p}") + if root.get_descendant(p) is not None: + raise ValueError(f"Path already set: {str(p)}") + _, map_of_maps = create_subtrees(path_map) + return _AddPathsExpression(root, map_of_maps) + + +def _is_true_source_expression( + candidate: expression.Expression, dest: expression.Expression +) -> bool: + """True if dest is an expression derived from candidate through add_paths. + + More precisely, true if dest is the result of zero or more add_paths + operations on candidate, where the (source of)* dest is candidate. + + A "true source" for a node dest has an identical NodeTensor to dest, and the + parent of "true source" is a "true source" for the dest. This means if I have + a child X of dest, I could make a new tree where X is a child of the source. + + See add_to. + + Args: + candidate: the possible source + dest: the possible destination. + + Returns: + True iff source is a true source expression of dest. + """ + if dest is candidate: + return True + if isinstance(dest, _AddPathsExpression): + return _is_true_source_expression(candidate, dest._origin) # pylint: disable=protected-access + return False + + +def add_to( + root: expression.Expression, origins: Mapping[path.Path, expression.Expression] +): + """Copies subtrees from the origins to the root. + + This operation can be used to reduce the number of expressions in the graph. + 1. The path must be present in the associated origin. + 2. The root must not have any of the paths already. + 3. The root must already have the parent of the path. + 4. The parent of the path in the root must be a source expression of the + parent of the path in the origin. + + Args: + root: the original tree that has new expressions added to it. + origins: mapping from path to trees that have subtrees at the path. + + Returns: + A tree with the root and the additional subtrees. 
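# Usage sketch for add_paths/add_to, mirroring test_add_paths and
# test_add_to later in this patch (root, root_2 as built in test_add_to):
new_root = expression_add.add_paths(
    expr,
    {path.Path(["user", "friends_copy"]):
         expr.get_descendant_or_error(path.Path(["user", "friends"]))})
# add_to then copies a subtree back onto a true source of its parent:
root_3 = expression_add.add_to(
    root, {path.Path(["user", "friends_3"]): root_2})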
+ """ + for p, origin_root in origins.items(): + path_parent = p.get_parent() + if not _is_true_source_expression( + root.get_descendant_or_error(path_parent), + origin_root.get_descendant_or_error(path_parent), + ): + raise ValueError(f"Not a true source for tree with {str(p)}") + if root.get_descendant(p) is not None: + raise ValueError(f"Already contains {str(p)}.") + path_map = { + p: origin_root.get_descendant_or_error(p) for p, origin_root in origins.items() + } + return add_paths(root, path_map) diff --git a/struct2tensor/expression_add_test.py b/struct2tensor/expression_add_test.py index c41d52d..cadc204 100644 --- a/struct2tensor/expression_add_test.py +++ b/struct2tensor/expression_add_test.py @@ -21,87 +21,102 @@ class ModifyTest(absltest.TestCase): + def test_add_paths(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor() + ) + new_root = expression_add.add_paths( + expr, + { + path.Path(["user", "friends_copy"]): expr.get_descendant_or_error( + path.Path(["user", "friends"]) + ) + }, + ) + new_field = new_root.get_descendant_or_error( + path.Path(["user", "friends_copy"]) + ) + self.assertIsNotNone(new_field) + self.assertTrue(new_field.is_repeated) + self.assertEqual(new_field.type, tf.string) + self.assertTrue(new_field.is_leaf) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertEqual(leaf_node.values.dtype, tf.string) + self.assertEqual(new_field.known_field_names(), frozenset()) - def test_add_paths(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor()) - new_root = expression_add.add_paths( - expr, { - path.Path(["user", "friends_copy"]): - expr.get_descendant_or_error(path.Path(["user", "friends"])) - }) - new_field = new_root.get_descendant_or_error( - path.Path(["user", "friends_copy"])) - self.assertIsNotNone(new_field) - self.assertTrue(new_field.is_repeated) - self.assertEqual(new_field.type, tf.string) - self.assertTrue(new_field.is_leaf) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertEqual(leaf_node.values.dtype, tf.string) - self.assertEqual(new_field.known_field_names(), frozenset()) + def test_add_to(self): + root = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor() + ) + root_1 = expression_add.add_paths( + root, + { + path.Path(["user", "friends_2"]): root.get_descendant_or_error( + path.Path(["user", "friends"]) + ) + }, + ) + root_2 = expression_add.add_paths( + root_1, + { + path.Path(["user", "friends_3"]): root_1.get_descendant_or_error( + path.Path(["user", "friends_2"]) + ) + }, + ) + root_3 = expression_add.add_to(root, {path.Path(["user", "friends_3"]): root_2}) - def test_add_to(self): - root = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor()) - root_1 = expression_add.add_paths( - root, { - path.Path(["user", "friends_2"]): - root.get_descendant_or_error(path.Path(["user", "friends"])) - }) - root_2 = expression_add.add_paths( - root_1, { - path.Path(["user", "friends_3"]): - root_1.get_descendant_or_error( - path.Path(["user", "friends_2"])) - }) - root_3 = expression_add.add_to(root, - {path.Path(["user", "friends_3"]): root_2}) + new_field = root_3.get_descendant_or_error(path.Path(["user", "friends_3"])) + self.assertIsNotNone(new_field) + self.assertTrue(new_field.is_repeated) + self.assertEqual(new_field.type, tf.string) + leaf_node = 
expression_test_util.calculate_value_slowly(new_field) + self.assertEqual(leaf_node.values.dtype, tf.string) - new_field = root_3.get_descendant_or_error(path.Path(["user", "friends_3"])) - self.assertIsNotNone(new_field) - self.assertTrue(new_field.is_repeated) - self.assertEqual(new_field.type, tf.string) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertEqual(leaf_node.values.dtype, tf.string) - - def test_add_to_already_existing_path(self): - with self.assertRaises(ValueError): - root = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor()) - root_1 = expression_add.add_paths( - root, { - path.Path(["user", "friends_2"]): - root.get_descendant_or_error(path.Path(["user", "friends"])) - }) - root_2 = expression_add.add_paths( - root_1, { - path.Path(["user", "friends_3"]): - root_1.get_descendant_or_error( - path.Path(["user", "friends_2"])) - }) - expression_add.add_to(root, {path.Path(["user", "friends"]): root_2}) - - def test_add_to_with_lenient_field_names(self): - root = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor_with_lenient_field_names(), - validate_step_format=False, - ) - root_1 = expression_add.add_paths( - root, - { - path.Path( - ["user", "friends_2$"], validate_step_format=False - ): root.get_descendant_or_error( - path.Path(["user", "friends!:)"], validate_step_format=False) + def test_add_to_already_existing_path(self): + with self.assertRaises(ValueError): + root = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor() + ) + root_1 = expression_add.add_paths( + root, + { + path.Path(["user", "friends_2"]): root.get_descendant_or_error( + path.Path(["user", "friends"]) + ) + }, ) - }, - ) + root_2 = expression_add.add_paths( + root_1, + { + path.Path(["user", "friends_3"]): root_1.get_descendant_or_error( + path.Path(["user", "friends_2"]) + ) + }, + ) + expression_add.add_to(root, {path.Path(["user", "friends"]): root_2}) + + def test_add_to_with_lenient_field_names(self): + root = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor_with_lenient_field_names(), + validate_step_format=False, + ) + root_1 = expression_add.add_paths( + root, + { + path.Path( + ["user", "friends_2$"], validate_step_format=False + ): root.get_descendant_or_error( + path.Path(["user", "friends!:)"], validate_step_format=False) + ) + }, + ) - new_field = root_1.get_descendant_or_error( - path.Path(["user", "friends_2$"], validate_step_format=False) - ) - self.assertIsNotNone(new_field) + new_field = root_1.get_descendant_or_error( + path.Path(["user", "friends_2$"], validate_step_format=False) + ) + self.assertIsNotNone(new_field) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_impl/apply_schema.py b/struct2tensor/expression_impl/apply_schema.py index fe65909..74b39e2 100644 --- a/struct2tensor/expression_impl/apply_schema.py +++ b/struct2tensor/expression_impl/apply_schema.py @@ -54,137 +54,143 @@ from struct2tensor import calculate_options, expression, path, prensor -def apply_schema(expr: expression.Expression, - schema: schema_pb2.Schema) -> expression.Expression: - schema_copy = schema_pb2.Schema() - schema_copy.CopyFrom(schema) - for x in schema_copy.feature: - _normalize_feature(x, schema_copy) - return _SchemaExpression(expr, schema_copy.feature, None) - - -def _normalize_feature(feature: schema_pb2.Feature, - schema: schema_pb2.Schema) 
-> None: - """Make each feature self-contained. - - If the feature references a global domain, copy the global domain locally. - Also do this for any child features. - - Note: the name of the domain is retained, so if we want to, we could attempt - to "unnormalize" the feature, recreating global domains. - - Args: - feature: feature to modify in place. - schema: schema containing any global domains. - """ - if feature.HasField("struct_domain"): - for x in feature.struct_domain.feature: - _normalize_feature(x, schema) - if feature.HasField("domain"): - for string_domain in schema.string_domain: - if string_domain.name == feature.domain: - feature.string_domain.CopyFrom(string_domain) - return - for int_domain in schema.int_domain: - if int_domain.name == feature.domain: - feature.int_domain.CopyFrom(int_domain) - return - for float_domain in schema.float_domain: - if float_domain.name == feature.domain: - feature.float_domain.CopyFrom(float_domain) - return - raise ValueError(f"Did not find domain {feature.domain} in schema {schema}") +def apply_schema( + expr: expression.Expression, schema: schema_pb2.Schema +) -> expression.Expression: + schema_copy = schema_pb2.Schema() + schema_copy.CopyFrom(schema) + for x in schema_copy.feature: + _normalize_feature(x, schema_copy) + return _SchemaExpression(expr, schema_copy.feature, None) -def _clean_feature(feature: schema_pb2.Feature) -> schema_pb2.Feature: - """Remove name and all children of a feature (if any exist), returning a copy. - - Args: - feature: input feature - - Returns: - cleaned feature - """ - copy = schema_pb2.Feature() - copy.CopyFrom(feature) - copy.ClearField("name") - if copy.HasField("struct_domain"): - del copy.struct_domain.feature[:] - return copy +def _normalize_feature(feature: schema_pb2.Feature, schema: schema_pb2.Schema) -> None: + """Make each feature self-contained. + If the feature references a global domain, copy the global domain locally. + Also do this for any child features. -def _apply_feature(original_child: expression.Expression, - feature: schema_pb2.Feature): - """Apply a feature to an expression. Feature should be "unclean".""" - feature_copy = [x for x in feature.struct_domain.feature - ] if feature.HasField("struct_domain") else [] - return _SchemaExpression(original_child, feature_copy, - _clean_feature(feature)) + Note: the name of the domain is retained, so if we want to, we could attempt + to "unnormalize" the feature, recreating global domains. + Args: + feature: feature to modify in place. + schema: schema containing any global domains. + """ + if feature.HasField("struct_domain"): + for x in feature.struct_domain.feature: + _normalize_feature(x, schema) + if feature.HasField("domain"): + for string_domain in schema.string_domain: + if string_domain.name == feature.domain: + feature.string_domain.CopyFrom(string_domain) + return + for int_domain in schema.int_domain: + if int_domain.name == feature.domain: + feature.int_domain.CopyFrom(int_domain) + return + for float_domain in schema.float_domain: + if float_domain.name == feature.domain: + feature.float_domain.CopyFrom(float_domain) + return + raise ValueError(f"Did not find domain {feature.domain} in schema {schema}") -class _SchemaExpression(expression.Expression, metaclass=abc.ABCMeta): - """An expression represents the application of a schema.""" - def __init__( - self, - original: expression.Expression, - child_features: Sequence[schema_pb2.Feature], - schema_feature: Optional[schema_pb2.Feature], - ): - """Create a new _SchemaExpression. 
+def _clean_feature(feature: schema_pb2.Feature) -> schema_pb2.Feature: + """Remove name and all children of a feature (if any exist), returning a copy. Args: - original: the original expression. - child_features: the uncleaned Feature protos for its children. - schema_feature: the optional cleaned feature for this node. + feature: input feature + + Returns: + cleaned feature """ - super().__init__( - original.is_repeated, - original.type, - schema_feature=schema_feature, - validate_step_format=original.validate_step_format, + copy = schema_pb2.Feature() + copy.CopyFrom(feature) + copy.ClearField("name") + if copy.HasField("struct_domain"): + del copy.struct_domain.feature[:] + return copy + + +def _apply_feature(original_child: expression.Expression, feature: schema_pb2.Feature): + """Apply a feature to an expression. Feature should be "unclean".""" + feature_copy = ( + [x for x in feature.struct_domain.feature] + if feature.HasField("struct_domain") + else [] ) - self._original = original - self._child_features = child_features - - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [self._original] - - def calculate(self, source_tensors: Sequence[prensor.NodeTensor], # pytype: disable=signature-mismatch # overriding-parameter-count-checks - destinations: Sequence[expression.Expression], - options: calculate_options.Options) -> prensor.NodeTensor: - del destinations, options - [original_result] = source_tensors - return original_result - - def calculation_is_identity(self) -> bool: - return True - - def calculation_equal(self, expr: expression.Expression) -> bool: - return expr.calculation_is_identity() - - def _find_feature_proto(self, field_name: path.Step - ) -> Optional[schema_pb2.Feature]: - for feature in self._child_features: - if feature.name == field_name: - return feature - return None - - def _get_child_impl(self, - field_name: path.Step) -> Optional[expression.Expression]: - original_child = self._original.get_child(field_name) - if original_child is None: - return None - feature_proto = self._find_feature_proto(field_name) - if feature_proto is None: - return original_child - return _apply_feature(original_child, feature_proto) - - def known_field_names(self) -> FrozenSet[path.Step]: - result = set(self._original.known_field_names()) - for feature_proto in self._child_features: - field_name = str(feature_proto.name) - associated_child = self.get_child(field_name) - if associated_child is not None: - result.add(field_name) - return frozenset(result) + return _SchemaExpression(original_child, feature_copy, _clean_feature(feature)) + + +class _SchemaExpression(expression.Expression, metaclass=abc.ABCMeta): + """An expression represents the application of a schema.""" + + def __init__( + self, + original: expression.Expression, + child_features: Sequence[schema_pb2.Feature], + schema_feature: Optional[schema_pb2.Feature], + ): + """Create a new _SchemaExpression. + + Args: + original: the original expression. + child_features: the uncleaned Feature protos for its children. + schema_feature: the optional cleaned feature for this node. 
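# Usage sketch for apply_schema above, mirroring test_apply_schema later in
# this patch; the input schema is copied, never mutated:
expr2 = apply_schema(
    expr, prensor_test_util.create_big_prensor_schema())
# Equivalent via the Expression method: expr.apply_schema(schema).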
+ """ + super().__init__( + original.is_repeated, + original.type, + schema_feature=schema_feature, + validate_step_format=original.validate_step_format, + ) + self._original = original + self._child_features = child_features + + def get_source_expressions(self) -> Sequence[expression.Expression]: + return [self._original] + + def calculate( + self, + source_tensors: Sequence[ + prensor.NodeTensor + ], # pytype: disable=signature-mismatch # overriding-parameter-count-checks + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + ) -> prensor.NodeTensor: + del destinations, options + [original_result] = source_tensors + return original_result + + def calculation_is_identity(self) -> bool: + return True + + def calculation_equal(self, expr: expression.Expression) -> bool: + return expr.calculation_is_identity() + + def _find_feature_proto( + self, field_name: path.Step + ) -> Optional[schema_pb2.Feature]: + for feature in self._child_features: + if feature.name == field_name: + return feature + return None + + def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]: + original_child = self._original.get_child(field_name) + if original_child is None: + return None + feature_proto = self._find_feature_proto(field_name) + if feature_proto is None: + return original_child + return _apply_feature(original_child, feature_proto) + + def known_field_names(self) -> FrozenSet[path.Step]: + result = set(self._original.known_field_names()) + for feature_proto in self._child_features: + field_name = str(feature_proto.name) + associated_child = self.get_child(field_name) + if associated_child is not None: + result.add(field_name) + return frozenset(result) diff --git a/struct2tensor/expression_impl/apply_schema_test.py b/struct2tensor/expression_impl/apply_schema_test.py index c40fe23..922d679 100644 --- a/struct2tensor/expression_impl/apply_schema_test.py +++ b/struct2tensor/expression_impl/apply_schema_test.py @@ -23,80 +23,82 @@ def _features_as_map(feature_list): - return {feature.name: feature for feature in feature_list} + return {feature.name: feature for feature in feature_list} class SchemaUtilTest(absltest.TestCase): - - def test_apply_schema(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - input_schema = prensor_test_util.create_big_prensor_schema() - expected_input_schema = copy.deepcopy(input_schema) - expr2 = expr.apply_schema(input_schema) - foo_expr = expr2.get_descendant(path.Path(["foo"])) - self.assertIsNotNone(foo_expr) - foorepeated_expr = expr2.get_descendant(path.Path(["foorepeated"])) - self.assertIsNotNone(foorepeated_expr) - doc_bar_expr = expr2.get_descendant(path.Path(["doc", "bar"])) - self.assertIsNotNone(doc_bar_expr) - - # Test that a domain already in the feature is maintained. - self.assertEqual(foo_expr.schema_feature.int_domain.max, 10) - # Test that an int_domain specified at the schema level is inserted - # correctly. - self.assertEqual(foorepeated_expr.schema_feature.int_domain.max, 10) - - # Test that a string_domain specified at the schema level is inserted - # correctly. - self.assertEqual(doc_bar_expr.schema_feature.string_domain.value[0], "a") - self.assertIsNotNone(expr2.get_descendant(path.Path(["user", "friends"]))) - self.assertIsNotNone(expr2.get_descendant(path.Path(["doc", "keep_me"]))) - - # Ensure the input schema wasn't mutated. 
- self.assertEqual(input_schema, expected_input_schema) - - def test_apply_empty_schema(self): - """Test that applying an empty schema does not filter out paths.""" - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - expr2 = expr.apply_schema(schema_pb2.Schema()) - foo_expr = expr2.get_descendant(path.Path(["foo"])) - self.assertIsNotNone(foo_expr) - foorepeated_expr = expr2.get_descendant(path.Path(["foorepeated"])) - self.assertIsNotNone(foorepeated_expr) - doc_bar_expr = expr2.get_descendant(path.Path(["doc", "bar"])) - self.assertIsNotNone(doc_bar_expr) - known_field_names = expr2.known_field_names() - self.assertIn("doc", known_field_names) - self.assertIn("foo", known_field_names) - self.assertIn("foorepeated", known_field_names) - self.assertIn("user", known_field_names) - - def test_get_schema(self): - """Integration test between get_schema and apply_schema.""" - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - expr2 = expr.apply_schema(prensor_test_util.create_big_prensor_schema()) - schema_result = expr2.get_schema() - feature_map = _features_as_map(schema_result.feature) - self.assertIn("foo", feature_map) - # Test that a domain already in the feature is maintained. - self.assertEqual(feature_map["foo"].int_domain.max, 10) - # Test that an int_domain specified at the schema level is inserted - # correctly. - self.assertEqual(feature_map["foorepeated"].int_domain.max, 10) - - def test_get_schema_lenient_names(self): - """Test that apply_schema preserves origin expression leniency.""" - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor_with_lenient_field_names(), - validate_step_format=False, - ) - expr2 = expr.apply_schema(schema_pb2.Schema()) - self.assertFalse(expr2.validate_step_format) - self.assertLen(expr2.get_known_descendants(), 6) + def test_apply_schema(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + input_schema = prensor_test_util.create_big_prensor_schema() + expected_input_schema = copy.deepcopy(input_schema) + expr2 = expr.apply_schema(input_schema) + foo_expr = expr2.get_descendant(path.Path(["foo"])) + self.assertIsNotNone(foo_expr) + foorepeated_expr = expr2.get_descendant(path.Path(["foorepeated"])) + self.assertIsNotNone(foorepeated_expr) + doc_bar_expr = expr2.get_descendant(path.Path(["doc", "bar"])) + self.assertIsNotNone(doc_bar_expr) + + # Test that a domain already in the feature is maintained. + self.assertEqual(foo_expr.schema_feature.int_domain.max, 10) + # Test that an int_domain specified at the schema level is inserted + # correctly. + self.assertEqual(foorepeated_expr.schema_feature.int_domain.max, 10) + + # Test that a string_domain specified at the schema level is inserted + # correctly. + self.assertEqual(doc_bar_expr.schema_feature.string_domain.value[0], "a") + self.assertIsNotNone(expr2.get_descendant(path.Path(["user", "friends"]))) + self.assertIsNotNone(expr2.get_descendant(path.Path(["doc", "keep_me"]))) + + # Ensure the input schema wasn't mutated. 
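+        # (_clean_feature copies each Feature via CopyFrom before clearing
+        # fields, so applying a schema should never mutate the caller's
+        # Schema proto; this assertion guards that contract.)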
+ self.assertEqual(input_schema, expected_input_schema) + + def test_apply_empty_schema(self): + """Test that applying an empty schema does not filter out paths.""" + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + expr2 = expr.apply_schema(schema_pb2.Schema()) + foo_expr = expr2.get_descendant(path.Path(["foo"])) + self.assertIsNotNone(foo_expr) + foorepeated_expr = expr2.get_descendant(path.Path(["foorepeated"])) + self.assertIsNotNone(foorepeated_expr) + doc_bar_expr = expr2.get_descendant(path.Path(["doc", "bar"])) + self.assertIsNotNone(doc_bar_expr) + known_field_names = expr2.known_field_names() + self.assertIn("doc", known_field_names) + self.assertIn("foo", known_field_names) + self.assertIn("foorepeated", known_field_names) + self.assertIn("user", known_field_names) + + def test_get_schema(self): + """Integration test between get_schema and apply_schema.""" + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + expr2 = expr.apply_schema(prensor_test_util.create_big_prensor_schema()) + schema_result = expr2.get_schema() + feature_map = _features_as_map(schema_result.feature) + self.assertIn("foo", feature_map) + # Test that a domain already in the feature is maintained. + self.assertEqual(feature_map["foo"].int_domain.max, 10) + # Test that an int_domain specified at the schema level is inserted + # correctly. + self.assertEqual(feature_map["foorepeated"].int_domain.max, 10) + + def test_get_schema_lenient_names(self): + """Test that apply_schema preserves origin expression leniency.""" + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor_with_lenient_field_names(), + validate_step_format=False, + ) + expr2 = expr.apply_schema(schema_pb2.Schema()) + self.assertFalse(expr2.validate_step_format) + self.assertLen(expr2.get_known_descendants(), 6) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_impl/broadcast.py b/struct2tensor/expression_impl/broadcast.py index 1164212..1262d0a 100644 --- a/struct2tensor/expression_impl/broadcast.py +++ b/struct2tensor/expression_impl/broadcast.py @@ -89,52 +89,55 @@ class _BroadcastExpression(expression.Leaf): - """A broadcast field.""" - - def __init__(self, origin: expression.Expression, - sibling: expression.Expression): - super().__init__(origin.is_repeated, origin.type) - if origin.type is None: - raise ValueError("Can only broadcast a field") - self._origin = origin - self._sibling = sibling - - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [self._origin, self._sibling] - - def calculate( - self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: - [origin_value, sibling_value] = sources - if not isinstance(origin_value, prensor.LeafNodeTensor): - raise ValueError("origin not a LeafNodeTensor") - if not isinstance(sibling_value, prensor.ChildNodeTensor): - raise ValueError("sibling value is not a ChildNodeTensor") - # For each i, for each v, if there exist exactly n values j such that: - # sibling_value.parent_index[i]==origin_value.parent_index[j] - # then there exists exactly n values k such that: - # new_parent_index[k] = i - # new_values[k] = origin_value.values[j] - # (Ordering is also preserved). 
-    [broadcasted_to_sibling_index, index_to_values
-    ] = struct2tensor_ops.equi_join_indices(sibling_value.parent_index,
-                                            origin_value.parent_index)
-    new_values = tf.gather(origin_value.values, index_to_values)
-    return prensor.LeafNodeTensor(broadcasted_to_sibling_index, new_values,
-                                  self.is_repeated)
-
-  def calculation_is_identity(self) -> bool:
-    return False
-
-  def calculation_equal(self, expr: expression.Expression) -> bool:
-    return isinstance(expr, _BroadcastExpression)
+    """A broadcast field."""
+
+    def __init__(self, origin: expression.Expression, sibling: expression.Expression):
+        super().__init__(origin.is_repeated, origin.type)
+        if origin.type is None:
+            raise ValueError("Can only broadcast a field")
+        self._origin = origin
+        self._sibling = sibling
+
+    def get_source_expressions(self) -> Sequence[expression.Expression]:
+        return [self._origin, self._sibling]
+
+    def calculate(
+        self,
+        sources: Sequence[prensor.NodeTensor],
+        destinations: Sequence[expression.Expression],
+        options: calculate_options.Options,
+        side_info: Optional[prensor.Prensor] = None,
+    ) -> prensor.NodeTensor:
+        [origin_value, sibling_value] = sources
+        if not isinstance(origin_value, prensor.LeafNodeTensor):
+            raise ValueError("origin not a LeafNodeTensor")
+        if not isinstance(sibling_value, prensor.ChildNodeTensor):
+            raise ValueError("sibling value is not a ChildNodeTensor")
+        # For each i, if there are exactly n indices j such that
+        #   sibling_value.parent_index[i] == origin_value.parent_index[j],
+        # then there are exactly n indices k such that
+        #   new_parent_index[k] = i and new_values[k] = origin_value.values[j].
+        # (Ordering is also preserved.)
+        [broadcasted_to_sibling_index, index_to_values] = (
+            struct2tensor_ops.equi_join_indices(
+                sibling_value.parent_index, origin_value.parent_index
+            )
+        )
+        new_values = tf.gather(origin_value.values, index_to_values)
+        return prensor.LeafNodeTensor(
+            broadcasted_to_sibling_index, new_values, self.is_repeated
+        )
+
+    def calculation_is_identity(self) -> bool:
+        return False
+
+    def calculation_equal(self, expr: expression.Expression) -> bool:
+        return isinstance(expr, _BroadcastExpression)
 
 
 class _RecalculateExpression(expression.Expression):
-  r"""Expression for recalculating a broadcasted subtree's parent indices.
+    r"""Expression for recalculating a broadcasted subtree's parent indices.
 
     This is needed because when a subtree is broadcasted, the nested fields
     could be duplicated. Their parent indices need to be updated.
@@ -169,160 +172,172 @@ class _RecalculateExpression(expression.Expression):
     ChildNodeTensor.
     """
- assert not isinstance(origin_value, prensor.RootNodeTensor), origin_value - - # The parent cannot be a LeafNodeTensor or RootNodeTensor, because - # a) a leaf node cannot have a submessage - # b) you cannot broadcast into a root - assert isinstance(parent_value, prensor.ChildNodeTensor), parent_value - - # We use equi_join_any_indices on the parent's `index_to_value` because it - # represents which child nodes were duplicated. Thus, which origin values - # also need to be duplicated. - [broadcasted_to_sibling_index, index_to_values - ] = struct2tensor_ops.equi_join_any_indices(parent_value.index_to_value, - origin_value.parent_index) - - if isinstance(origin_value, prensor.LeafNodeTensor): - new_values = tf.gather(origin_value.values, index_to_values) - return prensor.LeafNodeTensor(broadcasted_to_sibling_index, new_values, - self.is_repeated) - return prensor.ChildNodeTensor(broadcasted_to_sibling_index, - self.is_repeated, index_to_values) - - def calculation_is_identity(self) -> bool: - return False - - def calculation_equal(self, expr: expression.Expression) -> bool: - return isinstance(expr, _RecalculateExpression) - - def _get_child_impl(self, - field_name: path.Step) -> Optional[expression.Expression]: - """Gets the child expression. - - All children of a _RecalculateExpression should also be - _RecalculateExpressions, because we all fields in the subtree need to be - recalculated. This child is inherited from the origin's child. - - Args: - field_name: the name of the child. - - Returns: - an expression if a child exists, otherwise None. - """ - original_child = self._origin.get_child(field_name) - if original_child is None: - return None - return _RecalculateExpression(original_child, self) - - def known_field_names(self) -> FrozenSet[path.Step]: - return self._origin.known_field_names() + def __init__(self, origin: expression.Expression, parent: expression.Expression): + super().__init__( + origin.is_repeated, + origin.type, + validate_step_format=origin.validate_step_format, + ) + self._origin = origin + self._parent = parent + + def get_source_expressions(self) -> Sequence[expression.Expression]: + return [self._origin, self._parent] + + def calculate( + self, + sources: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> prensor.NodeTensor: + [origin_value, parent_value] = sources + + # We should never be recalculating a RootNodeTensor. + assert not isinstance(origin_value, prensor.RootNodeTensor), origin_value + + # The parent cannot be a LeafNodeTensor or RootNodeTensor, because + # a) a leaf node cannot have a submessage + # b) you cannot broadcast into a root + assert isinstance(parent_value, prensor.ChildNodeTensor), parent_value + + # We use equi_join_any_indices on the parent's `index_to_value` because it + # represents which child nodes were duplicated. Thus, which origin values + # also need to be duplicated. 
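+        # Illustrative example, using values from the broadcast tests below:
+        # if the parent's index_to_value is [0, 1, 2, 1, 2] (old children 1
+        # and 2 were each duplicated) and the origin's parent_index is
+        # [0, 1, 1, 2, 3], the equi-join yields new parent indices
+        # [0, 1, 1, 2, 3, 3, 4] and value indices [0, 1, 2, 3, 1, 2, 3].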
+        [broadcasted_to_sibling_index, index_to_values] = (
+            struct2tensor_ops.equi_join_any_indices(
+                parent_value.index_to_value, origin_value.parent_index
+            )
+        )
+
+        if isinstance(origin_value, prensor.LeafNodeTensor):
+            new_values = tf.gather(origin_value.values, index_to_values)
+            return prensor.LeafNodeTensor(
+                broadcasted_to_sibling_index, new_values, self.is_repeated
+            )
+        return prensor.ChildNodeTensor(
+            broadcasted_to_sibling_index, self.is_repeated, index_to_values
+        )
+
+    def calculation_is_identity(self) -> bool:
+        return False
+
+    def calculation_equal(self, expr: expression.Expression) -> bool:
+        return isinstance(expr, _RecalculateExpression)
+
+    def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]:
+        """Gets the child expression.
+
+        All children of a _RecalculateExpression should also be
+        _RecalculateExpressions, because all fields in the subtree need to be
+        recalculated. This child is inherited from the origin's child.
+
+        Args:
+          field_name: the name of the child.
+
+        Returns:
+          an expression if a child exists, otherwise None.
+        """
+        original_child = self._origin.get_child(field_name)
+        if original_child is None:
+            return None
+        return _RecalculateExpression(original_child, self)
+
+    def known_field_names(self) -> FrozenSet[path.Step]:
+        return self._origin.known_field_names()
 
 
 class _BroadcastChildExpression(expression.Expression):
-  """A broadcast expression for subtree.
+    """A broadcast expression for a subtree.
 
-  This expression represents the root of the subtree that is broadcasted.
-  It wraps all of its children as `_RecalculateExpression`s.
-  """
+    This expression represents the root of the subtree that is broadcasted.
+    It wraps all of its children as `_RecalculateExpression`s.
+    """
 
-  def __init__(self, origin: expression.Expression,
-               sibling: expression.Expression):
-    super().__init__(
-        origin.is_repeated,
-        origin.type,
-        validate_step_format=origin.validate_step_format,
-    )
-    self._origin = origin
-    self._sibling = sibling
-
-  def get_source_expressions(self) -> Sequence[expression.Expression]:
-    return [self._origin, self._sibling]
-
-  def calculate(
-      self,
-      sources: Sequence[prensor.NodeTensor],
-      destinations: Sequence[expression.Expression],
-      options: calculate_options.Options,
-      side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor:
-    [origin_value, sibling_value] = sources
-    if not isinstance(origin_value, prensor.ChildNodeTensor):
-      raise ValueError("origin not a ChildNodeTensor")
-    if not isinstance(sibling_value, prensor.ChildNodeTensor):
-      raise ValueError("sibling value is not a ChildNodeTensor")
-
-    [broadcasted_to_sibling_index, index_to_values
-    ] = struct2tensor_ops.equi_join_any_indices(sibling_value.parent_index,
-                                                origin_value.parent_index)
-    return prensor.ChildNodeTensor(
-        broadcasted_to_sibling_index,
-        self.is_repeated,
-        index_to_value=index_to_values)
-
-  def calculation_is_identity(self) -> bool:
-    return False
-
-  def calculation_equal(self, expr: expression.Expression) -> bool:
-    return isinstance(expr, _BroadcastChildExpression)
-
-  def _get_child_impl(self,
-                      field_name: path.Step) -> Optional[expression.Expression]:
-    return _RecalculateExpression(self._origin.get_child(field_name), self)
-
-  def known_field_names(self) -> FrozenSet[path.Step]:
-    return self._origin.known_field_names()
+    def __init__(self, origin: expression.Expression, sibling: expression.Expression):
+        super().__init__(
+            origin.is_repeated,
+            origin.type,
+            validate_step_format=origin.validate_step_format,
) + self._origin = origin + self._sibling = sibling + + def get_source_expressions(self) -> Sequence[expression.Expression]: + return [self._origin, self._sibling] + + def calculate( + self, + sources: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> prensor.NodeTensor: + [origin_value, sibling_value] = sources + if not isinstance(origin_value, prensor.ChildNodeTensor): + raise ValueError("origin not a ChildNodeTensor") + if not isinstance(sibling_value, prensor.ChildNodeTensor): + raise ValueError("sibling value is not a ChildNodeTensor") + + [broadcasted_to_sibling_index, index_to_values] = ( + struct2tensor_ops.equi_join_any_indices( + sibling_value.parent_index, origin_value.parent_index + ) + ) + return prensor.ChildNodeTensor( + broadcasted_to_sibling_index, + self.is_repeated, + index_to_value=index_to_values, + ) + + def calculation_is_identity(self) -> bool: + return False + + def calculation_equal(self, expr: expression.Expression) -> bool: + return isinstance(expr, _BroadcastChildExpression) + + def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]: + return _RecalculateExpression(self._origin.get_child(field_name), self) + + def known_field_names(self) -> FrozenSet[path.Step]: + return self._origin.known_field_names() def _broadcast_impl( - root: expression.Expression, origin: path.Path, sibling: path.Step, - new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]: - """Broadcasts origin to sibling for an expression.""" - sibling_path = origin.get_parent().get_child(sibling) - - origin_expression = root.get_descendant_or_error(origin) - - broadcast_expression_factory = ( - _BroadcastExpression - if origin_expression.is_leaf else _BroadcastChildExpression) + root: expression.Expression, + origin: path.Path, + sibling: path.Step, + new_field_name: path.Step, +) -> Tuple[expression.Expression, path.Path]: + """Broadcasts origin to sibling for an expression.""" + sibling_path = origin.get_parent().get_child(sibling) + + origin_expression = root.get_descendant_or_error(origin) + + broadcast_expression_factory = ( + _BroadcastExpression if origin_expression.is_leaf else _BroadcastChildExpression + ) - new_expr = broadcast_expression_factory( - origin_expression, - root.get_descendant_or_error(origin.get_parent().get_child(sibling))) - new_path = sibling_path.get_child(new_field_name) - result = expression_add.add_paths(root, {new_path: new_expr}) + new_expr = broadcast_expression_factory( + origin_expression, + root.get_descendant_or_error(origin.get_parent().get_child(sibling)), + ) + new_path = sibling_path.get_child(new_field_name) + result = expression_add.add_paths(root, {new_path: new_expr}) - return result, new_path + return result, new_path def broadcast_anonymous( - root: expression.Expression, origin: path.Path, - sibling: path.Step) -> Tuple[expression.Expression, path.Path]: - return _broadcast_impl(root, origin, sibling, path.get_anonymous_field()) - - -def broadcast(root: expression.Expression, origin: path.Path, - sibling_name: path.Step, - new_field_name: path.Step) -> expression.Expression: - return _broadcast_impl(root, origin, sibling_name, new_field_name)[0] + root: expression.Expression, origin: path.Path, sibling: path.Step +) -> Tuple[expression.Expression, path.Path]: + return _broadcast_impl(root, origin, sibling, path.get_anonymous_field()) + + +def broadcast( + root: expression.Expression, + 
origin: path.Path, + sibling_name: path.Step, + new_field_name: path.Step, +) -> expression.Expression: + return _broadcast_impl(root, origin, sibling_name, new_field_name)[0] diff --git a/struct2tensor/expression_impl/broadcast_test.py b/struct2tensor/expression_impl/broadcast_test.py index 081cf33..3e8d264 100644 --- a/struct2tensor/expression_impl/broadcast_test.py +++ b/struct2tensor/expression_impl/broadcast_test.py @@ -17,7 +17,7 @@ import tensorflow as tf from absl.testing import absltest from tensorflow.python.framework import ( - test_util, # pylint: disable=g-direct-tensorflow-import + test_util, # pylint: disable=g-direct-tensorflow-import ) from struct2tensor import create_expression, path @@ -26,230 +26,230 @@ class BroadcastTest(tf.test.TestCase): - - def test_broadcast_anonymous(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - new_root, p = broadcast.broadcast_anonymous(expr, path.Path(["foo"]), - "user") - new_field = new_root.get_descendant_or_error(p) - self.assertFalse(new_field.is_repeated) - self.assertEqual(new_field.type, tf.int32) - self.assertTrue(new_field.is_leaf) - self.assertTrue(new_field.calculation_equal(new_field)) - self.assertFalse(new_field.calculation_equal(expr)) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertEqual(leaf_node.values.dtype, tf.int32) - self.assertEqual(new_field.known_field_names(), frozenset()) - - sources = new_field.get_source_expressions() - self.assertLen(sources, 2) - self.assertIs(expr.get_child("foo"), sources[0]) - self.assertIs(expr.get_child("user"), sources[1]) - - def test_broadcast(self): - """Tests broadcast.broadcast(...), and indirectly tests set_path.""" - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - new_root = broadcast.broadcast(expr, path.Path(["foo"]), "user", - "new_field") - new_field = new_root.get_child("user").get_child("new_field") - self.assertIsNotNone(new_field) - self.assertFalse(new_field.is_repeated) - self.assertEqual(new_field.type, tf.int32) - self.assertTrue(new_field.is_leaf) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertEqual(leaf_node.values.dtype, tf.int32) - self.assertEqual(new_field.known_field_names(), frozenset()) - - def test_broadcast_substructure(self): - """Tests broadcast of a submessage. 
- - The result of broadcasting `user` into `doc` looks like: - { - foo: 9, - foorepeated: [9], - doc: [{bar:["a"], keep_me:False, new_user: [{friends:["a"]}]}], - user: [{friends:["a"]}] - }, - { - foo: 8, - foorepeated: [8, 7], - doc: [ + def test_broadcast_anonymous(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + new_root, p = broadcast.broadcast_anonymous(expr, path.Path(["foo"]), "user") + new_field = new_root.get_descendant_or_error(p) + self.assertFalse(new_field.is_repeated) + self.assertEqual(new_field.type, tf.int32) + self.assertTrue(new_field.is_leaf) + self.assertTrue(new_field.calculation_equal(new_field)) + self.assertFalse(new_field.calculation_equal(expr)) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertEqual(leaf_node.values.dtype, tf.int32) + self.assertEqual(new_field.known_field_names(), frozenset()) + + sources = new_field.get_source_expressions() + self.assertLen(sources, 2) + self.assertIs(expr.get_child("foo"), sources[0]) + self.assertIs(expr.get_child("user"), sources[1]) + + def test_broadcast(self): + """Tests broadcast.broadcast(...), and indirectly tests set_path.""" + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + new_root = broadcast.broadcast(expr, path.Path(["foo"]), "user", "new_field") + new_field = new_root.get_child("user").get_child("new_field") + self.assertIsNotNone(new_field) + self.assertFalse(new_field.is_repeated) + self.assertEqual(new_field.type, tf.int32) + self.assertTrue(new_field.is_leaf) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertEqual(leaf_node.values.dtype, tf.int32) + self.assertEqual(new_field.known_field_names(), frozenset()) + + def test_broadcast_substructure(self): + """Tests broadcast of a submessage. 
+ + The result of broadcasting `user` into `doc` looks like: + { + foo: 9, + foorepeated: [9], + doc: [{bar:["a"], keep_me:False, new_user: [{friends:["a"]}]}], + user: [{friends:["a"]}] + }, { - bar: ["b","c"], - keep_me: True, - new_user: [{friends:["b", "c"]},{friends:["d"]}] + foo: 8, + foorepeated: [8, 7], + doc: [ + { + bar: ["b","c"], + keep_me: True, + new_user: [{friends:["b", "c"]},{friends:["d"]}] + }, + { + bar: ["d"], + new_user: [{friends:["b", "c"]},{friends:["d"]}] + } + ], + user: [{friends:["b", "c"]},{friends:["d"]}], }, { - bar: ["d"], - new_user: [{friends:["b", "c"]},{friends:["d"]}] + foo: 7, + foorepeated: [6], + user: [{friends:["e"]}] } - ], - user: [{friends:["b", "c"]},{friends:["d"]}], - }, - { - foo: 7, - foorepeated: [6], - user: [{friends:["e"]}] - } - """ - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - new_root = broadcast.broadcast(expr, path.Path(["user"]), "doc", "new_user") - new_user = new_root.get_child("doc").get_child("new_user") - self.assertIsNotNone(new_user) - self.assertTrue(new_user.is_repeated) - self.assertIsNone(new_user.type) - self.assertFalse(new_user.is_leaf) - - new_user_node = expression_test_util.calculate_value_slowly(new_user) - self.assertAllEqual(new_user_node.parent_index, [0, 1, 1, 2, 2]) - self.assertAllEqual(new_user_node.index_to_value, [0, 1, 2, 1, 2]) - - new_friends = new_user.get_child("friends") - self.assertIsNotNone(new_friends) - self.assertTrue(new_friends.is_repeated) - self.assertEqual(new_friends.type, tf.string) - self.assertTrue(new_friends.is_leaf) - - new_friends_node = expression_test_util.calculate_value_slowly(new_friends) - self.assertEqual(new_friends_node.values.dtype, tf.string) - self.assertAllEqual(new_friends_node.values, - ["a", "b", "c", "d", "b", "c", "d"]) - self.assertAllEqual(new_friends_node.parent_index, [0, 1, 1, 2, 3, 3, 4]) - - def test_broadcast_substructure_deep(self): - """Tests broadcast of a submessage. 
- - The result of broadcasting `event` into `user` looks like: - { - foo: 9, - foorepeated: [9], - user: [{ - friends: ["a"], - new_event: [{ - doc:[{ - bar: ["a"], - keep_me:False - }] - }] - }], - event: [{doc:[{bar:["a"], keep_me:False}]}] - }, - { - foo: 8, - foorepeated: [8,7], - user: [{ - friends: ["b", "c"], - new_event: [{ - doc:[{ - bar: ["b","c"], - keep_me:True - }, - { - bar:["d"] - }] - }] - }, - { - friends: ["d"], - new_event: [{ - doc:[{ - bar: ["b","c"], - keep_me: True + """ + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + new_root = broadcast.broadcast(expr, path.Path(["user"]), "doc", "new_user") + new_user = new_root.get_child("doc").get_child("new_user") + self.assertIsNotNone(new_user) + self.assertTrue(new_user.is_repeated) + self.assertIsNone(new_user.type) + self.assertFalse(new_user.is_leaf) + + new_user_node = expression_test_util.calculate_value_slowly(new_user) + self.assertAllEqual(new_user_node.parent_index, [0, 1, 1, 2, 2]) + self.assertAllEqual(new_user_node.index_to_value, [0, 1, 2, 1, 2]) + + new_friends = new_user.get_child("friends") + self.assertIsNotNone(new_friends) + self.assertTrue(new_friends.is_repeated) + self.assertEqual(new_friends.type, tf.string) + self.assertTrue(new_friends.is_leaf) + + new_friends_node = expression_test_util.calculate_value_slowly(new_friends) + self.assertEqual(new_friends_node.values.dtype, tf.string) + self.assertAllEqual( + new_friends_node.values, ["a", "b", "c", "d", "b", "c", "d"] + ) + self.assertAllEqual(new_friends_node.parent_index, [0, 1, 1, 2, 3, 3, 4]) + + def test_broadcast_substructure_deep(self): + """Tests broadcast of a submessage. + + The result of broadcasting `event` into `user` looks like: + { + foo: 9, + foorepeated: [9], + user: [{ + friends: ["a"], + new_event: [{ + doc:[{ + bar: ["a"], + keep_me:False + }] + }] + }], + event: [{doc:[{bar:["a"], keep_me:False}]}] + }, + { + foo: 8, + foorepeated: [8,7], + user: [{ + friends: ["b", "c"], + new_event: [{ + doc:[{ + bar: ["b","c"], + keep_me:True + }, + { + bar:["d"] + }] + }] }, { - bar: ["d"] - }] - }] - }], - event: [{doc:[{bar:["b","c"], keep_me:True},{bar:["d"]}]}] - }, - { - foo:7, - foorepeated: [6], - user: [{ - friends:["e"], - new_event: [{}] - }], - event: [{}] - } - """ - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_deep_prensor()) - new_root = broadcast.broadcast(expr, path.Path(["event"]), "user", - "new_event") - new_event = new_root.get_child("user").get_child("new_event") - self.assertIsNotNone(new_event) - self.assertTrue(new_event.is_repeated) - self.assertIsNone(new_event.type) - self.assertFalse(new_event.is_leaf) - - new_event_node = expression_test_util.calculate_value_slowly(new_event) - self.assertAllEqual(new_event_node.parent_index, [0, 1, 2, 3]) - self.assertAllEqual(new_event_node.index_to_value, [0, 1, 1, 2]) - - new_doc = new_event.get_child("doc") - self.assertIsNotNone(new_doc) - self.assertTrue(new_doc.is_repeated) - self.assertIsNone(new_doc.type) - self.assertFalse(new_doc.is_leaf) - - new_doc_node = expression_test_util.calculate_value_slowly(new_doc) - self.assertAllEqual(new_doc_node.parent_index, [0, 1, 1, 2, 2]) - self.assertAllEqual(new_doc_node.index_to_value, [0, 1, 2, 1, 2]) - - new_bar = new_doc.get_child("bar") - self.assertIsNotNone(new_bar) - self.assertTrue(new_doc.is_repeated) - self.assertEqual(new_bar.type, tf.string) - self.assertTrue(new_bar.is_leaf) - - new_bar_node = 
expression_test_util.calculate_value_slowly(new_bar) - self.assertAllEqual(new_bar_node.values, - ["a", "b", "c", "d", "b", "c", "d"]) - self.assertAllEqual(new_bar_node.parent_index, [0, 1, 1, 2, 3, 3, 4]) - - new_keep_me = new_doc.get_child("keep_me") - self.assertIsNotNone(new_keep_me) - self.assertFalse(new_keep_me.is_repeated) - self.assertEqual(new_keep_me.type, tf.bool) - self.assertTrue(new_keep_me.is_leaf) - - new_keep_me_node = expression_test_util.calculate_value_slowly(new_keep_me) - self.assertAllEqual(new_keep_me_node.values, - [False, True, True]) - self.assertAllEqual(new_keep_me_node.parent_index, [0, 1, 3]) - - def test_broadcast_with_lenient_names(self): - """Tests broadcast with lenient step names.""" - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor_with_lenient_field_names(), - validate_step_format=False, - ) - new_root, new_path = broadcast.broadcast_anonymous( - expr, path.Path(["doc"], validate_step_format=False), "user" - ) - new_field = new_root.get_descendant_or_error(new_path) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 2, 2]) + friends: ["d"], + new_event: [{ + doc:[{ + bar: ["b","c"], + keep_me: True + }, + { + bar: ["d"] + }] + }] + }], + event: [{doc:[{bar:["b","c"], keep_me:True},{bar:["d"]}]}] + }, + { + foo:7, + foorepeated: [6], + user: [{ + friends:["e"], + new_event: [{}] + }], + event: [{}] + } + """ + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_deep_prensor() + ) + new_root = broadcast.broadcast(expr, path.Path(["event"]), "user", "new_event") + new_event = new_root.get_child("user").get_child("new_event") + self.assertIsNotNone(new_event) + self.assertTrue(new_event.is_repeated) + self.assertIsNone(new_event.type) + self.assertFalse(new_event.is_leaf) + + new_event_node = expression_test_util.calculate_value_slowly(new_event) + self.assertAllEqual(new_event_node.parent_index, [0, 1, 2, 3]) + self.assertAllEqual(new_event_node.index_to_value, [0, 1, 1, 2]) + + new_doc = new_event.get_child("doc") + self.assertIsNotNone(new_doc) + self.assertTrue(new_doc.is_repeated) + self.assertIsNone(new_doc.type) + self.assertFalse(new_doc.is_leaf) + + new_doc_node = expression_test_util.calculate_value_slowly(new_doc) + self.assertAllEqual(new_doc_node.parent_index, [0, 1, 1, 2, 2]) + self.assertAllEqual(new_doc_node.index_to_value, [0, 1, 2, 1, 2]) + + new_bar = new_doc.get_child("bar") + self.assertIsNotNone(new_bar) + self.assertTrue(new_doc.is_repeated) + self.assertEqual(new_bar.type, tf.string) + self.assertTrue(new_bar.is_leaf) + + new_bar_node = expression_test_util.calculate_value_slowly(new_bar) + self.assertAllEqual(new_bar_node.values, ["a", "b", "c", "d", "b", "c", "d"]) + self.assertAllEqual(new_bar_node.parent_index, [0, 1, 1, 2, 3, 3, 4]) + + new_keep_me = new_doc.get_child("keep_me") + self.assertIsNotNone(new_keep_me) + self.assertFalse(new_keep_me.is_repeated) + self.assertEqual(new_keep_me.type, tf.bool) + self.assertTrue(new_keep_me.is_leaf) + + new_keep_me_node = expression_test_util.calculate_value_slowly(new_keep_me) + self.assertAllEqual(new_keep_me_node.values, [False, True, True]) + self.assertAllEqual(new_keep_me_node.parent_index, [0, 1, 3]) + + def test_broadcast_with_lenient_names(self): + """Tests broadcast with lenient step names.""" + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor_with_lenient_field_names(), 
+ validate_step_format=False, + ) + new_root, new_path = broadcast.broadcast_anonymous( + expr, path.Path(["doc"], validate_step_format=False), "user" + ) + new_field = new_root.get_descendant_or_error(new_path) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 2, 2]) @test_util.run_all_in_graph_and_eager_modes class BroadcastValuesTest(tf.test.TestCase): - - def test_broadcast_and_calculate(self): - """Tests get_sparse_tensors on a deep tree.""" - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - new_root, new_path = broadcast.broadcast_anonymous(expr, path.Path(["foo"]), - "user") - new_field = new_root.get_descendant_or_error(new_path) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertAllEqual(leaf_node.parent_index, [0, 1, 2, 3]) - self.assertAllEqual(leaf_node.values, [9, 8, 8, 7]) + def test_broadcast_and_calculate(self): + """Tests get_sparse_tensors on a deep tree.""" + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + new_root, new_path = broadcast.broadcast_anonymous( + expr, path.Path(["foo"]), "user" + ) + new_field = new_root.get_descendant_or_error(new_path) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertAllEqual(leaf_node.parent_index, [0, 1, 2, 3]) + self.assertAllEqual(leaf_node.values, [9, 8, 8, 7]) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_impl/depth_limit.py b/struct2tensor/expression_impl/depth_limit.py index 34108c3..487c0d0 100644 --- a/struct2tensor/expression_impl/depth_limit.py +++ b/struct2tensor/expression_impl/depth_limit.py @@ -44,53 +44,52 @@ from struct2tensor import calculate_options, expression, path, prensor -def limit_depth(expr: expression.Expression, - depth_limit: int) -> expression.Expression: - """Limit the depth to nodes k steps from expr.""" - return _DepthLimitExpression(expr, depth_limit) +def limit_depth(expr: expression.Expression, depth_limit: int) -> expression.Expression: + """Limit the depth to nodes k steps from expr.""" + return _DepthLimitExpression(expr, depth_limit) class _DepthLimitExpression(expression.Expression): - """Project all subfields of an expression.""" - - def __init__(self, origin: expression.Expression, depth_limit: int): - super().__init__( - origin.is_repeated, - origin.type, - validate_step_format=origin.validate_step_format, - ) - self._origin = origin - self._depth_limit = depth_limit - - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [self._origin] - - def calculate( - self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: - if len(sources) != 1: - raise ValueError("Expected one source.") - return sources[0] - - def calculation_is_identity(self) -> bool: - return True - - def calculation_equal(self, expr: expression.Expression) -> bool: - return expr.calculation_is_identity() - - def _get_child_impl(self, - field_name: path.Step) -> Optional[expression.Expression]: - if self._depth_limit == 0: - return None - origin_child = self._origin.get_child(field_name) - if origin_child is None: - return None - return _DepthLimitExpression(origin_child, self._depth_limit - 1) - - def known_field_names(self) -> FrozenSet[path.Step]: - if 
self._depth_limit == 0:
-      return frozenset()
-    return self._origin.known_field_names()
+    """An expression whose children are limited to a given depth."""
+
+    def __init__(self, origin: expression.Expression, depth_limit: int):
+        super().__init__(
+            origin.is_repeated,
+            origin.type,
+            validate_step_format=origin.validate_step_format,
+        )
+        self._origin = origin
+        self._depth_limit = depth_limit
+
+    def get_source_expressions(self) -> Sequence[expression.Expression]:
+        return [self._origin]
+
+    def calculate(
+        self,
+        sources: Sequence[prensor.NodeTensor],
+        destinations: Sequence[expression.Expression],
+        options: calculate_options.Options,
+        side_info: Optional[prensor.Prensor] = None,
+    ) -> prensor.NodeTensor:
+        if len(sources) != 1:
+            raise ValueError("Expected one source.")
+        return sources[0]
+
+    def calculation_is_identity(self) -> bool:
+        return True
+
+    def calculation_equal(self, expr: expression.Expression) -> bool:
+        return expr.calculation_is_identity()
+
+    def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]:
+        if self._depth_limit == 0:
+            return None
+        origin_child = self._origin.get_child(field_name)
+        if origin_child is None:
+            return None
+        return _DepthLimitExpression(origin_child, self._depth_limit - 1)
+
+    def known_field_names(self) -> FrozenSet[path.Step]:
+        if self._depth_limit == 0:
+            return frozenset()
+        return self._origin.known_field_names()
diff --git a/struct2tensor/expression_impl/depth_limit_test.py b/struct2tensor/expression_impl/depth_limit_test.py
index a7472f9..7924acd 100644
--- a/struct2tensor/expression_impl/depth_limit_test.py
+++ b/struct2tensor/expression_impl/depth_limit_test.py
@@ -22,80 +22,80 @@
 class DepthLimitTest(absltest.TestCase):
-
-  def test_depth_limit_1(self):
-    """Tests depth_limit with a limit of 1.
-
-    Starting with a prensor expression representing:
-    {foo:9, foorepeated:[9], user:[{friends:["a"]}],
-      event:{doc:[{bar:["a"], keep_me:False}]}}
-    {foo:8, foorepeated:[8,7],
-      event:{doc:[{bar:["b","c"], keep_me:True},{bar:["d"]}]},
-      user:[{friends:["b", "c"]}, {friends:["d"]}]}
-    {foo:7, foorepeated:[6], user:[friends:["e"]], event:{}}
-
-    After depth_limit.limit_depth(expr, 1), you lose event.doc, user.friends,
-    event.doc.bar, and event.doc.keep_me.
-
-    {foo:9, foorepeated:[9], user:[{}], event:{}}
-    {foo:8, foorepeated:[8,7], event:{},user:[{}, {}]}
-    {foo:7, foorepeated:[6], user:[{}], event:{}}
-
-    """
-    expr = create_expression.create_expression_from_prensor(
-        prensor_test_util.create_deep_prensor())
-    new_root = depth_limit.limit_depth(expr, 1)
-
-    self.assertIsNone(
-        new_root.get_descendant(path.Path(["event", "doc", "bar"])))
-    self.assertIsNone(
-        new_root.get_descendant(path.Path(["event", "doc", "keep_me"])))
-
-    self.assertIsNone(new_root.get_descendant(path.Path(["user", "friends"])))
-    self.assertIsNone(new_root.get_descendant(path.Path(["event", "doc"])))
-
-    self.assertIsNotNone(new_root.get_descendant(path.Path(["foo"])))
-    self.assertIsNotNone(new_root.get_descendant(path.Path(["foorepeated"])))
-    self.assertIsNotNone(new_root.get_descendant(path.Path(["user"])))
-    self.assertIsNotNone(new_root.get_descendant(path.Path(["event"])))
-
-  def test_depth_limit_2(self):
-    """Tests depth_limit with a limit of 2.
- - Starting with a prensor expression representing: - {foo:9, foorepeated:[9], user:[{friends:["a"]}], - event:{doc:[{bar:["a"], keep_me:False}]}} - {foo:8, foorepeated:[8,7], - event:{doc:[{bar:["b","c"], keep_me:True},{bar:["d"]}]}, - user:[{friends:["b", "c"]}, {friends:["d"]}]} - {foo:7, foorepeated:[6], user:[friends:["e"]], event:{}} - - After depth_limit.limit_depth(expr, 2), you lose event.doc.bar - and event.doc.keep_me: - - {foo:9, foorepeated:[9], user:[{friends:["a"]}], event:{doc:[{}]}} - {foo:8, foorepeated:[8,7], event:{doc:[{},{}]}, - user:[{friends:["b", "c"]}, {friends:["d"]}]} - {foo:7, foorepeated:[6], user:[friends:["e"]], event:{}} - - """ - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_deep_prensor(), validate_step_format=False) - new_root = depth_limit.limit_depth(expr, 2) - self.assertFalse(new_root.validate_step_format) - self.assertIsNone( - new_root.get_descendant(path.Path(["event", "doc", "bar"]))) - self.assertIsNone( - new_root.get_descendant(path.Path(["event", "doc", "keep_me"]))) - - self.assertIsNotNone(new_root.get_descendant(path.Path(["foo"]))) - self.assertIsNotNone(new_root.get_descendant(path.Path(["foorepeated"]))) - self.assertIsNotNone(new_root.get_descendant(path.Path(["user"]))) - self.assertIsNotNone( - new_root.get_descendant(path.Path(["user", "friends"]))) - self.assertIsNotNone(new_root.get_descendant(path.Path(["event"]))) - self.assertIsNotNone(new_root.get_descendant(path.Path(["event", "doc"]))) + def test_depth_limit_1(self): + """Tests depth_limit with a limit of 1. + + Starting with a prensor expression representing: + {foo:9, foorepeated:[9], user:[{friends:["a"]}], + event:{doc:[{bar:["a"], keep_me:False}]}} + {foo:8, foorepeated:[8,7], + event:{doc:[{bar:["b","c"], keep_me:True},{bar:["d"]}]}, + user:[{friends:["b", "c"]}, {friends:["d"]}]} + {foo:7, foorepeated:[6], user:[friends:["e"]], event:{}} + + After depth_limit.limit_depth(expr, 1), you lose event.doc, user.friends, + event.doc.bar, and event.doc.keep_me. + + {foo:9, foorepeated:[9], user:[{}], event:{}} + {foo:8, foorepeated:[8,7], event:{},user:[{}, {}]} + {foo:7, foorepeated:[6], user:[{}], event:{}} + + """ + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_deep_prensor() + ) + new_root = depth_limit.limit_depth(expr, 1) + + self.assertIsNone(new_root.get_descendant(path.Path(["event", "doc", "bar"]))) + self.assertIsNone( + new_root.get_descendant(path.Path(["event", "doc", "keep_me"])) + ) + + self.assertIsNone(new_root.get_descendant(path.Path(["user", "friends"]))) + self.assertIsNone(new_root.get_descendant(path.Path(["event", "doc"]))) + + self.assertIsNotNone(new_root.get_descendant(path.Path(["foo"]))) + self.assertIsNotNone(new_root.get_descendant(path.Path(["foorepeated"]))) + self.assertIsNotNone(new_root.get_descendant(path.Path(["user"]))) + self.assertIsNotNone(new_root.get_descendant(path.Path(["event"]))) + + def test_depth_limit_2(self): + """Tests depth_limit with a limit of 2. 
+ + Starting with a prensor expression representing: + {foo:9, foorepeated:[9], user:[{friends:["a"]}], + event:{doc:[{bar:["a"], keep_me:False}]}} + {foo:8, foorepeated:[8,7], + event:{doc:[{bar:["b","c"], keep_me:True},{bar:["d"]}]}, + user:[{friends:["b", "c"]}, {friends:["d"]}]} + {foo:7, foorepeated:[6], user:[friends:["e"]], event:{}} + + After depth_limit.limit_depth(expr, 2), you lose event.doc.bar + and event.doc.keep_me: + + {foo:9, foorepeated:[9], user:[{friends:["a"]}], event:{doc:[{}]}} + {foo:8, foorepeated:[8,7], event:{doc:[{},{}]}, + user:[{friends:["b", "c"]}, {friends:["d"]}]} + {foo:7, foorepeated:[6], user:[friends:["e"]], event:{}} + + """ + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_deep_prensor(), validate_step_format=False + ) + new_root = depth_limit.limit_depth(expr, 2) + self.assertFalse(new_root.validate_step_format) + self.assertIsNone(new_root.get_descendant(path.Path(["event", "doc", "bar"]))) + self.assertIsNone( + new_root.get_descendant(path.Path(["event", "doc", "keep_me"])) + ) + + self.assertIsNotNone(new_root.get_descendant(path.Path(["foo"]))) + self.assertIsNotNone(new_root.get_descendant(path.Path(["foorepeated"]))) + self.assertIsNotNone(new_root.get_descendant(path.Path(["user"]))) + self.assertIsNotNone(new_root.get_descendant(path.Path(["user", "friends"]))) + self.assertIsNotNone(new_root.get_descendant(path.Path(["event"]))) + self.assertIsNotNone(new_root.get_descendant(path.Path(["event", "doc"]))) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_impl/filter_expression.py b/struct2tensor/expression_impl/filter_expression.py index 4b7d026..f493d6f 100644 --- a/struct2tensor/expression_impl/filter_expression.py +++ b/struct2tensor/expression_impl/filter_expression.py @@ -68,309 +68,325 @@ from struct2tensor.ops import struct2tensor_ops -def filter_by_sibling(expr: expression.Expression, p: path.Path, - sibling_field_name: path.Step, - new_field_name: path.Step) -> expression.Expression: - """Filter an expression by its sibling. - - This is similar to boolean_mask. The shape of the path being filtered and - the sibling must be identical (e.g., each parent object must have an - equal number of source and sibling children). - - Args: - expr: the root expression. - p: a path to the source to be filtered. - sibling_field_name: the sibling to use as a mask. - new_field_name: a new sibling to create. - - Returns: - a new root. - """ - origin = expr.get_descendant_or_error(p) - parent_path = p.get_parent() - sibling = expr.get_descendant_or_error( - parent_path.get_child(sibling_field_name)) - new_expr = _FilterBySiblingExpression(origin, sibling) - new_path = parent_path.get_child(new_field_name) - return expression_add.add_paths(expr, {new_path: new_expr}) - - -def filter_by_child(expr: expression.Expression, p: path.Path, - child_field_name: path.Step, - new_field_name: path.Step) -> expression.Expression: - """Filter an expression by an optional boolean child field. - - If the child field is present and True, then keep that parent. - Otherwise, drop the parent. - - Args: - expr: the original expression - p: the path to filter. - child_field_name: the boolean child field to use to filter. - new_field_name: the new, filtered version of path. - - Returns: - The new root expression. 
- """ - origin = expr.get_descendant_or_error(p) - child = origin.get_child_or_error(child_field_name) - new_expr = _FilterByChildExpression(origin, child) - new_path = p.get_parent().get_child(new_field_name) - - return expression_add.add_paths(expr, {new_path: new_expr}) +def filter_by_sibling( + expr: expression.Expression, + p: path.Path, + sibling_field_name: path.Step, + new_field_name: path.Step, +) -> expression.Expression: + """Filter an expression by its sibling. + + This is similar to boolean_mask. The shape of the path being filtered and + the sibling must be identical (e.g., each parent object must have an + equal number of source and sibling children). + Args: + expr: the root expression. + p: a path to the source to be filtered. + sibling_field_name: the sibling to use as a mask. + new_field_name: a new sibling to create. -#################### Private methods and classes follow ######################## + Returns: + a new root. + """ + origin = expr.get_descendant_or_error(p) + parent_path = p.get_parent() + sibling = expr.get_descendant_or_error(parent_path.get_child(sibling_field_name)) + new_expr = _FilterBySiblingExpression(origin, sibling) + new_path = parent_path.get_child(new_field_name) + return expression_add.add_paths(expr, {new_path: new_expr}) -class _FilterRootNodeTensor(prensor.RootNodeTensor): - """The value of the root.""" +def filter_by_child( + expr: expression.Expression, + p: path.Path, + child_field_name: path.Step, + new_field_name: path.Step, +) -> expression.Expression: + """Filter an expression by an optional boolean child field. - def __init__(self, size: tf.Tensor, indices_to_keep: tf.Tensor): - """Initialize a root tensor that has indices_to_keep. + If the child field is present and True, then keep that parent. + Otherwise, drop the parent. Args: - size: an int64 scalar tensor - indices_to_keep: a 1D int64 tensor (int64 vector) - """ - super().__init__(size) - self._indices_to_keep = indices_to_keep + expr: the original expression + p: the path to filter. + child_field_name: the boolean child field to use to filter. + new_field_name: the new, filtered version of path. - @property - def indices_to_keep(self) -> tf.Tensor: - return self._indices_to_keep + Returns: + The new root expression. + """ + origin = expr.get_descendant_or_error(p) + child = origin.get_child_or_error(child_field_name) + new_expr = _FilterByChildExpression(origin, child) + new_path = p.get_parent().get_child(new_field_name) - def __str__(self): - return "_FilterRootNodeTensor" + return expression_add.add_paths(expr, {new_path: new_expr}) -class _FilterChildNodeTensor(prensor.ChildNodeTensor): - """The value of an intermediate node.""" +#################### Private methods and classes follow ######################## - def __init__(self, parent_index: tf.Tensor, is_repeated: bool, - indices_to_keep: tf.Tensor): - """Initialize a child node tensor with indices_to_keep. 
- Args: - parent_index: an int64 1D tensor (an int64 vector) - is_repeated: true if there can be more than one element per parent - indices_to_keep: an int64 1D tensor (an int64 vector) - """ - super().__init__(parent_index, is_repeated) - self._indices_to_keep = indices_to_keep - - def __str__(self): - return "{} FilterChildNode".format( - "repeated" if self.is_repeated else "optional") - - @property - def indices_to_keep(self) -> tf.Tensor: - return self._indices_to_keep - - -def _filter_by_self_indices_to_keep(node_value: prensor.NodeTensor, - self_indices_to_keep: tf.Tensor - ) -> prensor.NodeTensor: - """Filter the node by the indices you want to keep.""" - if isinstance(node_value, prensor.RootNodeTensor): - return _FilterRootNodeTensor( - tf.size(self_indices_to_keep), self_indices_to_keep) - if isinstance(node_value, prensor.ChildNodeTensor): - return _FilterChildNodeTensor( - tf.gather(node_value.parent_index, self_indices_to_keep), - node_value.is_repeated, self_indices_to_keep) - if isinstance(node_value, prensor.LeafNodeTensor): - return prensor.LeafNodeTensor( - tf.gather(node_value.parent_index, self_indices_to_keep), - tf.gather(node_value.values, self_indices_to_keep), - node_value.is_repeated) - raise ValueError("Unknown NodeValue type") +class _FilterRootNodeTensor(prensor.RootNodeTensor): + """The value of the root.""" + def __init__(self, size: tf.Tensor, indices_to_keep: tf.Tensor): + """Initialize a root tensor that has indices_to_keep. -def _filter_by_parent_indices_to_keep( - node_value: Union[prensor.ChildNodeTensor, prensor.LeafNodeTensor], - parent_indices_to_keep: tf.Tensor) -> prensor.NodeTensor: - """Filter by parent indices to keep.""" - [new_parent_index, self_indices_to_keep - ] = struct2tensor_ops.equi_join_indices(parent_indices_to_keep, - node_value.parent_index) - if isinstance(node_value, prensor.ChildNodeTensor): - return _FilterChildNodeTensor(new_parent_index, node_value.is_repeated, - self_indices_to_keep) - if isinstance(node_value, prensor.LeafNodeTensor): - return prensor.LeafNodeTensor( - new_parent_index, tf.gather(node_value.values, self_indices_to_keep), - node_value.is_repeated) - raise ValueError("Unknown NodeValue type") + Args: + size: an int64 scalar tensor + indices_to_keep: a 1D int64 tensor (int64 vector) + """ + super().__init__(size) + self._indices_to_keep = indices_to_keep + @property + def indices_to_keep(self) -> tf.Tensor: + return self._indices_to_keep -class _FilterChildByParentIndicesToKeepExpression(expression.Expression): - """Filter all descendants of a _FilterBy(Sibling/Child)Expression. - - This expression is used to represent the descendants of - _FilterBySiblingExpression and _FilterByChildExpression. 
- """ - - def __init__(self, origin: expression.Expression, - parent: expression.Expression): - super().__init__( - origin.is_repeated, - origin.type, - validate_step_format=origin.validate_step_format, - ) - self._origin = origin - self._parent = parent + def __str__(self): + return "_FilterRootNodeTensor" - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [self._origin, self._parent] - def calculate( - self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: - [origin_value, parent_value] = sources - if (not isinstance(parent_value, - (_FilterChildNodeTensor, _FilterRootNodeTensor))): - raise ValueError("Parent must be a filtered node") +class _FilterChildNodeTensor(prensor.ChildNodeTensor): + """The value of an intermediate node.""" + + def __init__( + self, parent_index: tf.Tensor, is_repeated: bool, indices_to_keep: tf.Tensor + ): + """Initialize a child node tensor with indices_to_keep. + + Args: + parent_index: an int64 1D tensor (an int64 vector) + is_repeated: true if there can be more than one element per parent + indices_to_keep: an int64 1D tensor (an int64 vector) + """ + super().__init__(parent_index, is_repeated) + self._indices_to_keep = indices_to_keep + + def __str__(self): + return "{} FilterChildNode".format( + "repeated" if self.is_repeated else "optional" + ) + + @property + def indices_to_keep(self) -> tf.Tensor: + return self._indices_to_keep + + +def _filter_by_self_indices_to_keep( + node_value: prensor.NodeTensor, self_indices_to_keep: tf.Tensor +) -> prensor.NodeTensor: + """Filter the node by the indices you want to keep.""" + if isinstance(node_value, prensor.RootNodeTensor): + return _FilterRootNodeTensor( + tf.size(self_indices_to_keep), self_indices_to_keep + ) + if isinstance(node_value, prensor.ChildNodeTensor): + return _FilterChildNodeTensor( + tf.gather(node_value.parent_index, self_indices_to_keep), + node_value.is_repeated, + self_indices_to_keep, + ) + if isinstance(node_value, prensor.LeafNodeTensor): + return prensor.LeafNodeTensor( + tf.gather(node_value.parent_index, self_indices_to_keep), + tf.gather(node_value.values, self_indices_to_keep), + node_value.is_repeated, + ) + raise ValueError("Unknown NodeValue type") - if (not isinstance(origin_value, - (prensor.ChildNodeTensor, prensor.LeafNodeTensor))): - raise ValueError("Original must be a child or leaf node") - parent_indices_to_keep = parent_value.indices_to_keep - return _filter_by_parent_indices_to_keep(origin_value, - parent_indices_to_keep) +def _filter_by_parent_indices_to_keep( + node_value: Union[prensor.ChildNodeTensor, prensor.LeafNodeTensor], + parent_indices_to_keep: tf.Tensor, +) -> prensor.NodeTensor: + """Filter by parent indices to keep.""" + [new_parent_index, self_indices_to_keep] = struct2tensor_ops.equi_join_indices( + parent_indices_to_keep, node_value.parent_index + ) + if isinstance(node_value, prensor.ChildNodeTensor): + return _FilterChildNodeTensor( + new_parent_index, node_value.is_repeated, self_indices_to_keep + ) + if isinstance(node_value, prensor.LeafNodeTensor): + return prensor.LeafNodeTensor( + new_parent_index, + tf.gather(node_value.values, self_indices_to_keep), + node_value.is_repeated, + ) + raise ValueError("Unknown NodeValue type") - def calculation_is_identity(self) -> bool: - return False - def calculation_equal(self, expr: expression.Expression) -> bool: - return 
isinstance(expr, _FilterChildByParentIndicesToKeepExpression)
+
+    def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]:
+        original_child = self._origin.get_child(field_name)
+        if original_child is None:
+            return None
+        return _FilterChildByParentIndicesToKeepExpression(original_child, self)
+
+    def known_field_names(self) -> FrozenSet[path.Step]:
+        return self._origin.known_field_names()
 
 
 def _self_indices_where_true(leaf_node: prensor.LeafNodeTensor) -> tf.Tensor:
-  self_index = tf.range(
-      tf.size(leaf_node.parent_index, out_type=tf.int64), dtype=tf.int64)
-  return tf.boolean_mask(self_index, leaf_node.values)
+    self_index = tf.range(
+        tf.size(leaf_node.parent_index, out_type=tf.int64), dtype=tf.int64
+    )
+    return tf.boolean_mask(self_index, leaf_node.values)
 
 
 def _parent_indices_where_true(leaf_node: prensor.LeafNodeTensor) -> tf.Tensor:
-  return tf.boolean_mask(leaf_node.parent_index, leaf_node.values)
+    return tf.boolean_mask(leaf_node.parent_index, leaf_node.values)
 
 
 class _FilterBySiblingExpression(expression.Expression):
-  """Project all subfields of an expression."""
-
-  def __init__(self, origin: expression.Expression,
-               sibling: expression.Expression):
-    super().__init__(
-        origin.is_repeated,
-        origin.type,
-        validate_step_format=origin.validate_step_format,
-    )
-    self._origin = origin
-    self._sibling = sibling
-    if sibling.type != tf.bool:
-      raise ValueError("Sibling must be a boolean leaf.")
-
-  def get_source_expressions(self) -> Sequence[expression.Expression]:
-    return [self._origin, self._sibling]
-
-  def calculate(
-      self,
-      sources: Sequence[prensor.NodeTensor],
-      destinations:
Sequence[expression.Expression],
-      options: calculate_options.Options,
-      side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor:
-    [origin_value, sibling_value] = sources
-    if not isinstance(origin_value,
-                      (prensor.ChildNodeTensor, prensor.LeafNodeTensor)):
-      raise ValueError("Origin should not be a root")
-    if not isinstance(sibling_value, prensor.LeafNodeTensor):
-      raise ValueError("Sibling should be a leaf")
-    # Check that they are the same shape.
-    tf.assert_equal(sibling_value.parent_index, origin_value.parent_index)
-    self_indices_to_keep = _self_indices_where_true(sibling_value)
-    return _filter_by_self_indices_to_keep(origin_value, self_indices_to_keep)
-
-  def calculation_is_identity(self) -> bool:
-    return False
-
-  def calculation_equal(self, expr: expression.Expression) -> bool:
-    return isinstance(self, _FilterBySiblingExpression)
-
-  def _get_child_impl(self,
-                      field_name: path.Step) -> Optional[expression.Expression]:
-    original = self._origin.get_child(field_name)
-    if original is None:
-      return None
-    return _FilterChildByParentIndicesToKeepExpression(original, self)
-
-  def known_field_names(self) -> FrozenSet[path.Step]:
-    return self._origin.known_field_names()
+    """Filters the origin expression by a boolean sibling leaf."""
+
+    def __init__(self, origin: expression.Expression, sibling: expression.Expression):
+        super().__init__(
+            origin.is_repeated,
+            origin.type,
+            validate_step_format=origin.validate_step_format,
+        )
+        self._origin = origin
+        self._sibling = sibling
+        if sibling.type != tf.bool:
+            raise ValueError("Sibling must be a boolean leaf.")
+
+    def get_source_expressions(self) -> Sequence[expression.Expression]:
+        return [self._origin, self._sibling]
+
+    def calculate(
+        self,
+        sources: Sequence[prensor.NodeTensor],
+        destinations: Sequence[expression.Expression],
+        options: calculate_options.Options,
+        side_info: Optional[prensor.Prensor] = None,
+    ) -> prensor.NodeTensor:
+        [origin_value, sibling_value] = sources
+        if not isinstance(
+            origin_value, (prensor.ChildNodeTensor, prensor.LeafNodeTensor)
+        ):
+            raise ValueError("Origin should not be a root")
+        if not isinstance(sibling_value, prensor.LeafNodeTensor):
+            raise ValueError("Sibling should be a leaf")
+        # Check that they are the same shape.
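For intuition, a small standalone sketch (illustrative values, not part of the patch) of what _self_indices_where_true computes for the sibling leaf once this shape check passes:

    import tensorflow as tf

    # A boolean sibling leaf with parent_index [0, 1, 1] and values
    # [False, True, False]: the indices to keep are the self indices of
    # the True entries.
    parent_index = tf.constant([0, 1, 1], dtype=tf.int64)
    values = tf.constant([False, True, False])
    self_index = tf.range(tf.size(parent_index, out_type=tf.int64), dtype=tf.int64)
    print(tf.boolean_mask(self_index, values))  # [1]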
+ tf.assert_equal(sibling_value.parent_index, origin_value.parent_index) + self_indices_to_keep = _self_indices_where_true(sibling_value) + return _filter_by_self_indices_to_keep(origin_value, self_indices_to_keep) + + def calculation_is_identity(self) -> bool: + return False + + def calculation_equal(self, expr: expression.Expression) -> bool: + return isinstance(self, _FilterBySiblingExpression) + + def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]: + original = self._origin.get_child(field_name) + if original is None: + return None + return _FilterChildByParentIndicesToKeepExpression(original, self) + + def known_field_names(self) -> FrozenSet[path.Step]: + return self._origin.known_field_names() class _FilterByChildExpression(expression.Expression): - """Project all subfields of an expression.""" - - def __init__(self, origin: expression.Expression, - child: expression.Expression): - super().__init__( - origin.is_repeated, - origin.type, - validate_step_format=origin.validate_step_format, - ) - self._origin = origin - self._child = child - if child.type != tf.bool: - raise ValueError("Child must be boolean leaf") - if child.is_repeated: - raise ValueError("Child must not be repeated") - - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [self._origin, self._child] - - def calculate( - self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: - [origin_value, child_value] = sources - if not isinstance(origin_value, - (prensor.ChildNodeTensor, prensor.LeafNodeTensor)): - raise ValueError("Origin is the wrong type.") - if not isinstance(child_value, prensor.LeafNodeTensor): - raise ValueError("Child is not a leaf node.") - # Here, we already know it is an optional field, so we don't need to check - # the shape of parent_index. 
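A standalone sketch (illustrative values, not from the patch) of the keep rule relied on here: because the boolean child is optional, each parent occurs at most once in the child's parent_index, so _parent_indices_where_true directly yields the origin indices to keep.

    import tensorflow as tf

    # Child leaf keep_me with parent_index [0, 1] and values [False, True]:
    # only origin element 1 survives the filter.
    parent_index = tf.constant([0, 1], dtype=tf.int64)
    values = tf.constant([False, True])
    print(tf.boolean_mask(parent_index, values))  # [1]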
-    self_indices_to_keep = _parent_indices_where_true(child_value)
-    return _filter_by_self_indices_to_keep(origin_value, self_indices_to_keep)
-
-  def calculation_is_identity(self) -> bool:
-    return False
-
-  def calculation_equal(self, expr: expression.Expression) -> bool:
-    return isinstance(self, _FilterByChildExpression)
-
-  def _get_child_impl(self,
-                      field_name: path.Step) -> Optional[expression.Expression]:
-    original = self._origin.get_child(field_name)
-    if original is None:
-      return None
-    return _FilterChildByParentIndicesToKeepExpression(original, self)
-
-  def known_field_names(self) -> FrozenSet[path.Step]:
-    return self._origin.known_field_names()
+    """Filters the origin expression by an optional boolean child leaf."""
+
+    def __init__(self, origin: expression.Expression, child: expression.Expression):
+        super().__init__(
+            origin.is_repeated,
+            origin.type,
+            validate_step_format=origin.validate_step_format,
+        )
+        self._origin = origin
+        self._child = child
+        if child.type != tf.bool:
+            raise ValueError("Child must be boolean leaf")
+        if child.is_repeated:
+            raise ValueError("Child must not be repeated")
+
+    def get_source_expressions(self) -> Sequence[expression.Expression]:
+        return [self._origin, self._child]
+
+    def calculate(
+        self,
+        sources: Sequence[prensor.NodeTensor],
+        destinations: Sequence[expression.Expression],
+        options: calculate_options.Options,
+        side_info: Optional[prensor.Prensor] = None,
+    ) -> prensor.NodeTensor:
+        [origin_value, child_value] = sources
+        if not isinstance(
+            origin_value, (prensor.ChildNodeTensor, prensor.LeafNodeTensor)
+        ):
+            raise ValueError("Origin is the wrong type.")
+        if not isinstance(child_value, prensor.LeafNodeTensor):
+            raise ValueError("Child is not a leaf node.")
+        # Here, we already know it is an optional field, so we don't need to check
+        # the shape of parent_index.
+        self_indices_to_keep = _parent_indices_where_true(child_value)
+        return _filter_by_self_indices_to_keep(origin_value, self_indices_to_keep)
+
+    def calculation_is_identity(self) -> bool:
+        return False
+
+    def calculation_equal(self, expr: expression.Expression) -> bool:
+        return isinstance(self, _FilterByChildExpression)
+
+    def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]:
+        original = self._origin.get_child(field_name)
+        if original is None:
+            return None
+        return _FilterChildByParentIndicesToKeepExpression(original, self)
+
+    def known_field_names(self) -> FrozenSet[path.Step]:
+        return self._origin.known_field_names()
diff --git a/struct2tensor/expression_impl/filter_expression_test.py b/struct2tensor/expression_impl/filter_expression_test.py
index 620ecfc..e514659 100644
--- a/struct2tensor/expression_impl/filter_expression_test.py
+++ b/struct2tensor/expression_impl/filter_expression_test.py
@@ -16,22 +16,22 @@ import tensorflow as tf
 from absl.testing import absltest
 from tensorflow.python.framework import (
-    test_util,  # pylint: disable=g-direct-tensorflow-import
+    test_util,  # pylint: disable=g-direct-tensorflow-import
 )  # For tf.Session.Run against a Prensor
 from struct2tensor import (
-    calculate,
-    create_expression,
-    path,
-    prensor,  # pylint: disable=unused-import
+    calculate,
+    create_expression,
+    path,
+    prensor,  # pylint: disable=unused-import
 )
 from struct2tensor.expression_impl import filter_expression, proto_test_util
 from struct2tensor.test import expression_test_util, prensor_test_util, test_pb2


 def _create_slice_and_project_example():
-    r"""Creates an example for test_slice_and_project.
+ r"""Creates an example for test_slice_and_project. Returns: ------*----------------- @@ -46,8 +46,9 @@ def _create_slice_and_project_example(): This also adds an action_mask that is false for the zeroeth action. """ - return proto_test_util.text_to_expression([ - """ + return proto_test_util.text_to_expression( + [ + """ event:{ action_mask: [false, true] action:{ @@ -73,7 +74,8 @@ def _create_slice_and_project_example(): action:{ doc_id:"f" } - }""", """ + }""", + """ event:{ action_mask: [false] action:{ @@ -91,12 +93,14 @@ def _create_slice_and_project_example(): action:{ doc_id:"j" } - }""" - ], test_pb2.Session) + }""", + ], + test_pb2.Session, + ) def _create_nested_prensor(): - r"""Creates a prensor representing a list of nested protocol buffers. + r"""Creates a prensor representing a list of nested protocol buffers. -----*---------------------------------------------------- / \ \ @@ -114,24 +118,25 @@ def _create_nested_prensor(): keep_me:True} {} """ - return prensor.create_prensor_from_descendant_nodes({ - path.Path([]): - prensor_test_util.create_root_node(3), - path.Path(["doc"]): - prensor_test_util.create_child_node([0, 1, 1], True), - path.Path(["keep_my_sib"]): - prensor_test_util.create_repeated_leaf_node([0, 1, 1], - [False, True, False]), - path.Path(["doc", "bar"]): - prensor_test_util.create_repeated_leaf_node([0, 1, 1, 2], - ["a", "b", "c", "d"]), - path.Path(["doc", "keep_me"]): - prensor_test_util.create_optional_leaf_node([0, 1], [False, True]) - }) + return prensor.create_prensor_from_descendant_nodes( + { + path.Path([]): prensor_test_util.create_root_node(3), + path.Path(["doc"]): prensor_test_util.create_child_node([0, 1, 1], True), + path.Path(["keep_my_sib"]): prensor_test_util.create_repeated_leaf_node( + [0, 1, 1], [False, True, False] + ), + path.Path(["doc", "bar"]): prensor_test_util.create_repeated_leaf_node( + [0, 1, 1, 2], ["a", "b", "c", "d"] + ), + path.Path(["doc", "keep_me"]): prensor_test_util.create_optional_leaf_node( + [0, 1], [False, True] + ), + } + ) def _create_nested_prensor_2(): - r"""Creates a prensor representing a list of nested protocol buffers. + r"""Creates a prensor representing a list of nested protocol buffers. keep_me no longer has a value in doc0. 
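To make the tree diagrams above concrete: a prensor encodes each level with a parent_index vector. A minimal sketch using the same helpers as this test file (illustrative, not itself part of the tests):

    # Three root sessions; doc parent_index [0, 1, 1] attaches doc0 to
    # session0 and doc1, doc2 to session1, exactly as drawn above.
    tiny = prensor.create_prensor_from_descendant_nodes(
        {
            path.Path([]): prensor_test_util.create_root_node(3),
            path.Path(["doc"]): prensor_test_util.create_child_node([0, 1, 1], True),
        }
    )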
@@ -151,109 +156,139 @@ def _create_nested_prensor_2(): keep_me:True} {} """ - return prensor.create_prensor_from_descendant_nodes({ - path.Path([]): - prensor_test_util.create_root_node(3), - path.Path(["doc"]): - prensor_test_util.create_child_node([0, 1, 1], True), - path.Path(["keep_my_sib"]): - prensor_test_util.create_repeated_leaf_node([0, 1, 1], - [False, True, False]), - path.Path(["doc", "bar"]): - prensor_test_util.create_repeated_leaf_node([0, 1, 1, 2], - ["a", "b", "c", "d"]), - path.Path(["doc", "keep_me"]): - prensor_test_util.create_optional_leaf_node([1], [True]) - }) + return prensor.create_prensor_from_descendant_nodes( + { + path.Path([]): prensor_test_util.create_root_node(3), + path.Path(["doc"]): prensor_test_util.create_child_node([0, 1, 1], True), + path.Path(["keep_my_sib"]): prensor_test_util.create_repeated_leaf_node( + [0, 1, 1], [False, True, False] + ), + path.Path(["doc", "bar"]): prensor_test_util.create_repeated_leaf_node( + [0, 1, 1, 2], ["a", "b", "c", "d"] + ), + path.Path(["doc", "keep_me"]): prensor_test_util.create_optional_leaf_node( + [1], [True] + ), + } + ) @test_util.run_all_in_graph_and_eager_modes class FilterExpressionTest(tf.test.TestCase): - - def test_filter_by_child(self): - """Tests filter_by_child.""" - root = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor(), validate_step_format=False) - root_2 = filter_expression.filter_by_child(root, path.create_path("doc"), - "keep_me", "new_doc") - self.assertFalse(root_2.validate_step_format) - [result] = calculate.calculate_prensors([root_2]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc" - ])).node.parent_index, [1]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", "keep_me" - ])).node.parent_index, [0]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "keep_me"])).node.values, - [True]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "bar"])).node.parent_index, - [0, 0]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "bar"])).node.values, - [b"b", b"c"]) - - def test_filter_by_child_create_nested_prensor(self): - """Tests filter_by_child.""" - root = create_expression.create_expression_from_prensor( - _create_nested_prensor()) - root_2 = filter_expression.filter_by_child(root, path.create_path("doc"), - "keep_me", "new_doc") - [result] = calculate.calculate_prensors([root_2]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc" - ])).node.parent_index, [1]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", "keep_me" - ])).node.parent_index, [0]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "keep_me"])).node.values, - [True]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "bar"])).node.parent_index, - [0, 0]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "bar"])).node.values, - [b"b", b"c"]) - - def test_filter_by_child_create_nested_prensor_2(self): - """Tests filter_by_child. - - In particular, it checks for the case where parent_index != self index. 
- """ - root = create_expression.create_expression_from_prensor( - _create_nested_prensor_2()) - root_2 = filter_expression.filter_by_child(root, path.create_path("doc"), - "keep_me", "new_doc") - [result] = calculate.calculate_prensors([root_2]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc" - ])).node.parent_index, [1]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", "keep_me" - ])).node.parent_index, [0]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "keep_me"])).node.values, - [True]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "bar"])).node.parent_index, - [0, 0]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "bar"])).node.values, - [b"b", b"c"]) - - def test_filter_by_sibling(self): - r"""Tests filter_by_sibling. + def test_filter_by_child(self): + """Tests filter_by_child.""" + root = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor(), validate_step_format=False + ) + root_2 = filter_expression.filter_by_child( + root, path.create_path("doc"), "keep_me", "new_doc" + ) + self.assertFalse(root_2.validate_step_format) + [result] = calculate.calculate_prensors([root_2]) + self.assertAllEqual( + result.get_descendant_or_error(path.Path(["new_doc"])).node.parent_index, + [1], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "keep_me"]) + ).node.parent_index, + [0], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "keep_me"]) + ).node.values, + [True], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "bar"]) + ).node.parent_index, + [0, 0], + ) + self.assertAllEqual( + result.get_descendant_or_error(path.Path(["new_doc", "bar"])).node.values, + [b"b", b"c"], + ) + + def test_filter_by_child_create_nested_prensor(self): + """Tests filter_by_child.""" + root = create_expression.create_expression_from_prensor( + _create_nested_prensor() + ) + root_2 = filter_expression.filter_by_child( + root, path.create_path("doc"), "keep_me", "new_doc" + ) + [result] = calculate.calculate_prensors([root_2]) + self.assertAllEqual( + result.get_descendant_or_error(path.Path(["new_doc"])).node.parent_index, + [1], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "keep_me"]) + ).node.parent_index, + [0], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "keep_me"]) + ).node.values, + [True], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "bar"]) + ).node.parent_index, + [0, 0], + ) + self.assertAllEqual( + result.get_descendant_or_error(path.Path(["new_doc", "bar"])).node.values, + [b"b", b"c"], + ) + + def test_filter_by_child_create_nested_prensor_2(self): + """Tests filter_by_child. + + In particular, it checks for the case where parent_index != self index. 
+ """ + root = create_expression.create_expression_from_prensor( + _create_nested_prensor_2() + ) + root_2 = filter_expression.filter_by_child( + root, path.create_path("doc"), "keep_me", "new_doc" + ) + [result] = calculate.calculate_prensors([root_2]) + self.assertAllEqual( + result.get_descendant_or_error(path.Path(["new_doc"])).node.parent_index, + [1], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "keep_me"]) + ).node.parent_index, + [0], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "keep_me"]) + ).node.values, + [True], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "bar"]) + ).node.parent_index, + [0, 0], + ) + self.assertAllEqual( + result.get_descendant_or_error(path.Path(["new_doc", "bar"])).node.values, + [b"b", b"c"], + ) + + def test_filter_by_sibling(self): + r"""Tests filter_by_sibling. Beginning with the struct: -----*---------------------------------------------------- @@ -278,55 +313,66 @@ def test_filter_by_sibling(self): bar:"b" bar:"c" keep_me:True """ - root = create_expression.create_expression_from_prensor( - _create_nested_prensor()) - root_2 = filter_expression.filter_by_sibling(root, path.create_path("doc"), - "keep_my_sib", "new_doc") - [result] = calculate.calculate_prensors([root_2]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc" - ])).node.parent_index, [1]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", "keep_me" - ])).node.parent_index, [0]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "keep_me"])).node.values, - [True]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "bar"])).node.parent_index, - [0, 0]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "bar"])).node.values, - [b"b", b"c"]) - - def test_slice_and_project_mini(self): - """Testing a part of query_test.test_slice_and_project. - - Originally, there was an error with query_test.test_slice_and_project, - caused by filter_expression. I used this unit test to find and ultimately - fix the error. 
- """ - root = _create_slice_and_project_example() - - root_2 = filter_expression.filter_by_sibling(root, - path.Path(["event", "action"]), - "action_mask", "taction") - calculate_value = expression_test_util.calculate_value_slowly( - root_2.get_descendant_or_error(path.Path(["event", "taction"]))) - value_indices = calculate_value.parent_index - self.assertAllEqual(value_indices, [0, 1, 2, 4, 4]) - - def test_indices_where_true(self): - input_prensor_node = prensor_test_util.create_repeated_leaf_node( - [0, 0, 1, 1, 2, 2, 3, 4, 4, 4], - [False, True, False, True, False, True, False, False, True, True]) - tensor_result = filter_expression._self_indices_where_true( - input_prensor_node) - self.assertAllEqual(tensor_result, [1, 3, 5, 8, 9]) + root = create_expression.create_expression_from_prensor( + _create_nested_prensor() + ) + root_2 = filter_expression.filter_by_sibling( + root, path.create_path("doc"), "keep_my_sib", "new_doc" + ) + [result] = calculate.calculate_prensors([root_2]) + self.assertAllEqual( + result.get_descendant_or_error(path.Path(["new_doc"])).node.parent_index, + [1], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "keep_me"]) + ).node.parent_index, + [0], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "keep_me"]) + ).node.values, + [True], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "bar"]) + ).node.parent_index, + [0, 0], + ) + self.assertAllEqual( + result.get_descendant_or_error(path.Path(["new_doc", "bar"])).node.values, + [b"b", b"c"], + ) + + def test_slice_and_project_mini(self): + """Testing a part of query_test.test_slice_and_project. + + Originally, there was an error with query_test.test_slice_and_project, + caused by filter_expression. I used this unit test to find and ultimately + fix the error. + """ + root = _create_slice_and_project_example() + + root_2 = filter_expression.filter_by_sibling( + root, path.Path(["event", "action"]), "action_mask", "taction" + ) + calculate_value = expression_test_util.calculate_value_slowly( + root_2.get_descendant_or_error(path.Path(["event", "taction"])) + ) + value_indices = calculate_value.parent_index + self.assertAllEqual(value_indices, [0, 1, 2, 4, 4]) + + def test_indices_where_true(self): + input_prensor_node = prensor_test_util.create_repeated_leaf_node( + [0, 0, 1, 1, 2, 2, 3, 4, 4, 4], + [False, True, False, True, False, True, False, False, True, True], + ) + tensor_result = filter_expression._self_indices_where_true(input_prensor_node) + self.assertAllEqual(tensor_result, [1, 3, 5, 8, 9]) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_impl/index.py b/struct2tensor/expression_impl/index.py index 5154c21..4b11017 100644 --- a/struct2tensor/expression_impl/index.py +++ b/struct2tensor/expression_impl/index.py @@ -124,145 +124,154 @@ from struct2tensor.expression_impl import size -def get_positional_index(expr: expression.Expression, source_path: path.Path, - new_field_name: path.Step - ) -> Tuple[expression.Expression, path.Path]: - """Gets the positional index. - - Given a field with parent_index [0,1,1,2,3,4,4], this returns: - parent_index [0,1,1,2,3,4,4] and value [0,0,1,0,0,0,1] - - Args: - expr: original expression - source_path: path in expression to get index of. - new_field_name: the name of the new field. - - Returns: - The new expression and the new path as a pair. 
- """ - new_path = source_path.get_parent().get_child(new_field_name) - return expression_add.add_paths( - expr, { - new_path: - _PositionalIndexExpression( - expr.get_descendant_or_error(source_path)) - }), new_path - - -def get_index_from_end(t: expression.Expression, source_path: path.Path, - new_field_name: path.Step - ) -> Tuple[expression.Expression, path.Path]: - """Gets the number of steps from the end of the array. - - Given an array ["a", "b", "c"], with indices [0, 1, 2], the result of this - is [-3,-2,-1]. - - Args: - t: original expression - source_path: path in expression to get index of. - new_field_name: the name of the new field. - - Returns: - The new expression and the new path as a pair. - """ - new_path = source_path.get_parent().get_child(new_field_name) - work_expr, positional_index_path = get_positional_index( - t, source_path, path.get_anonymous_field()) - work_expr, size_path = size.size_anonymous(work_expr, source_path) - work_expr = expression_add.add_paths( - work_expr, { - new_path: - _PositionalIndexFromEndExpression( - work_expr.get_descendant_or_error(positional_index_path), - work_expr.get_descendant_or_error(size_path)) - }) - # Removing the intermediate anonymous nodes. - result = expression_add.add_to(t, {new_path: work_expr}) - return result, new_path +def get_positional_index( + expr: expression.Expression, source_path: path.Path, new_field_name: path.Step +) -> Tuple[expression.Expression, path.Path]: + """Gets the positional index. + Given a field with parent_index [0,1,1,2,3,4,4], this returns: + parent_index [0,1,1,2,3,4,4] and value [0,0,1,0,0,0,1] -class _PositionalIndexExpression(expression.Leaf): - """The positional index for the origin. + Args: + expr: original expression + source_path: path in expression to get index of. + new_field_name: the name of the new field. + + Returns: + The new expression and the new path as a pair. + """ + new_path = source_path.get_parent().get_child(new_field_name) + return expression_add.add_paths( + expr, + { + new_path: _PositionalIndexExpression( + expr.get_descendant_or_error(source_path) + ) + }, + ), new_path - _PositionalIndexExpression is intended to be a sibling of source. - The operation will return a field that has the same parent index as source. - """ - def __init__(self, origin: expression.Expression): - super().__init__(origin.is_repeated, tf.int64) - self._origin = origin +def get_index_from_end( + t: expression.Expression, source_path: path.Path, new_field_name: path.Step +) -> Tuple[expression.Expression, path.Path]: + """Gets the number of steps from the end of the array. - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [self._origin] + Given an array ["a", "b", "c"], with indices [0, 1, 2], the result of this + is [-3,-2,-1]. - def calculate( - self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: - [origin] = sources - if isinstance(origin, (prensor.LeafNodeTensor, prensor.ChildNodeTensor)): - return prensor.LeafNodeTensor( - origin.parent_index, - origin.get_positional_index(), - self.is_repeated) - raise ValueError("Cannot calculate the positional index of the root") + Args: + t: original expression + source_path: path in expression to get index of. + new_field_name: the name of the new field. - def calculation_is_identity(self) -> bool: - return False + Returns: + The new expression and the new path as a pair. 
+ """ + new_path = source_path.get_parent().get_child(new_field_name) + work_expr, positional_index_path = get_positional_index( + t, source_path, path.get_anonymous_field() + ) + work_expr, size_path = size.size_anonymous(work_expr, source_path) + work_expr = expression_add.add_paths( + work_expr, + { + new_path: _PositionalIndexFromEndExpression( + work_expr.get_descendant_or_error(positional_index_path), + work_expr.get_descendant_or_error(size_path), + ) + }, + ) + # Removing the intermediate anonymous nodes. + result = expression_add.add_to(t, {new_path: work_expr}) + return result, new_path - def calculation_equal(self, expr: expression.Expression) -> bool: - return isinstance(expr, _PositionalIndexExpression) +class _PositionalIndexExpression(expression.Leaf): + """The positional index for the origin. -class _PositionalIndexFromEndExpression(expression.Leaf): - """The positional index from the end. + _PositionalIndexExpression is intended to be a sibling of source. + The operation will return a field that has the same parent index as source. + """ - _PositionalIndexFromEndExpression is intended to be a sibling of - positional_index. - The operation will return a field that has the same parent index as source. - """ + def __init__(self, origin: expression.Expression): + super().__init__(origin.is_repeated, tf.int64) + self._origin = origin - def __init__(self, positional_index: expression.Expression, - size_inp: expression.Expression): - """Create a new expression. + def get_source_expressions(self) -> Sequence[expression.Expression]: + return [self._origin] - Note that while positional_index and size are both siblings of the - original field, positional_index is index-aligned with the field, - whereas size is a "required" field. + def calculate( + self, + sources: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> prensor.NodeTensor: + [origin] = sources + if isinstance(origin, (prensor.LeafNodeTensor, prensor.ChildNodeTensor)): + return prensor.LeafNodeTensor( + origin.parent_index, origin.get_positional_index(), self.is_repeated + ) + raise ValueError("Cannot calculate the positional index of the root") - Args: - positional_index: the positional index of the field (from - get_positional_index). - size_inp: the size of the field (from size.size). + def calculation_is_identity(self) -> bool: + return False + + def calculation_equal(self, expr: expression.Expression) -> bool: + return isinstance(expr, _PositionalIndexExpression) + + +class _PositionalIndexFromEndExpression(expression.Leaf): + """The positional index from the end. + + _PositionalIndexFromEndExpression is intended to be a sibling of + positional_index. + The operation will return a field that has the same parent index as source. 
""" - super().__init__(positional_index.is_repeated, tf.int64) - self._positional_index = positional_index - self._size = size_inp - - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [self._positional_index, self._size] - - def calculate( - self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: - [positional_index, size_value] = sources - if not isinstance(positional_index, prensor.LeafNodeTensor): - raise ValueError("positional_index must be a LeafNodeTensor") - if not isinstance(size_value, prensor.LeafNodeTensor): - raise ValueError("size_value must be a LeafNodeTensor") - - size_per_index = tf.gather(size_value.values, positional_index.parent_index) - return prensor.LeafNodeTensor(positional_index.parent_index, - positional_index.values - size_per_index, - self.is_repeated) - - def calculation_is_identity(self) -> bool: - return False - - def calculation_equal(self, expr: expression.Expression) -> bool: - return isinstance(expr, _PositionalIndexFromEndExpression) + + def __init__( + self, positional_index: expression.Expression, size_inp: expression.Expression + ): + """Create a new expression. + + Note that while positional_index and size are both siblings of the + original field, positional_index is index-aligned with the field, + whereas size is a "required" field. + + Args: + positional_index: the positional index of the field (from + get_positional_index). + size_inp: the size of the field (from size.size). + """ + super().__init__(positional_index.is_repeated, tf.int64) + self._positional_index = positional_index + self._size = size_inp + + def get_source_expressions(self) -> Sequence[expression.Expression]: + return [self._positional_index, self._size] + + def calculate( + self, + sources: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> prensor.NodeTensor: + [positional_index, size_value] = sources + if not isinstance(positional_index, prensor.LeafNodeTensor): + raise ValueError("positional_index must be a LeafNodeTensor") + if not isinstance(size_value, prensor.LeafNodeTensor): + raise ValueError("size_value must be a LeafNodeTensor") + + size_per_index = tf.gather(size_value.values, positional_index.parent_index) + return prensor.LeafNodeTensor( + positional_index.parent_index, + positional_index.values - size_per_index, + self.is_repeated, + ) + + def calculation_is_identity(self) -> bool: + return False + + def calculation_equal(self, expr: expression.Expression) -> bool: + return isinstance(expr, _PositionalIndexFromEndExpression) diff --git a/struct2tensor/expression_impl/index_test.py b/struct2tensor/expression_impl/index_test.py index 295a6ad..e5a3bd3 100644 --- a/struct2tensor/expression_impl/index_test.py +++ b/struct2tensor/expression_impl/index_test.py @@ -16,7 +16,7 @@ import tensorflow as tf from absl.testing import absltest from tensorflow.python.framework import ( - test_util, # pylint: disable=g-direct-tensorflow-import + test_util, # pylint: disable=g-direct-tensorflow-import ) from struct2tensor import create_expression, path @@ -25,66 +25,72 @@ class IndexTest(absltest.TestCase): + def test_get_positional_index(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor() + ) + new_root, new_path = 
index.get_positional_index( + expr, path.Path(["user", "friends"]), path.get_anonymous_field() + ) + new_field = new_root.get_descendant_or_error(new_path) + self.assertTrue(new_field.is_repeated) + self.assertEqual(new_field.type, tf.int64) + self.assertTrue(new_field.is_leaf) + self.assertTrue(new_field.calculation_equal(new_field)) + self.assertFalse(new_field.calculation_equal(expr)) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertEqual(leaf_node.values.dtype, tf.int64) + self.assertEqual(new_field.known_field_names(), frozenset()) - def test_get_positional_index(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor()) - new_root, new_path = index.get_positional_index( - expr, path.Path(["user", "friends"]), path.get_anonymous_field()) - new_field = new_root.get_descendant_or_error(new_path) - self.assertTrue(new_field.is_repeated) - self.assertEqual(new_field.type, tf.int64) - self.assertTrue(new_field.is_leaf) - self.assertTrue(new_field.calculation_equal(new_field)) - self.assertFalse(new_field.calculation_equal(expr)) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertEqual(leaf_node.values.dtype, tf.int64) - self.assertEqual(new_field.known_field_names(), frozenset()) - - def test_get_index_from_end(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor()) - new_root, new_path = index.get_index_from_end( - expr, path.Path(["user", "friends"]), path.get_anonymous_field()) - new_field = new_root.get_descendant_or_error(new_path) - self.assertTrue(new_field.is_repeated) - self.assertEqual(new_field.type, tf.int64) - self.assertTrue(new_field.is_leaf) - self.assertTrue(new_field.calculation_equal(new_field)) - self.assertFalse(new_field.calculation_equal(expr)) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertEqual(leaf_node.values.dtype, tf.int64) - self.assertEqual(new_field.known_field_names(), frozenset()) + def test_get_index_from_end(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor() + ) + new_root, new_path = index.get_index_from_end( + expr, path.Path(["user", "friends"]), path.get_anonymous_field() + ) + new_field = new_root.get_descendant_or_error(new_path) + self.assertTrue(new_field.is_repeated) + self.assertEqual(new_field.type, tf.int64) + self.assertTrue(new_field.is_leaf) + self.assertTrue(new_field.calculation_equal(new_field)) + self.assertFalse(new_field.calculation_equal(expr)) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertEqual(leaf_node.values.dtype, tf.int64) + self.assertEqual(new_field.known_field_names(), frozenset()) @test_util.run_all_in_graph_and_eager_modes class GetIndexValuesTest(tf.test.TestCase): + def test_get_positional_index_calculate(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor() + ) + new_root, new_path = index.get_positional_index( + expr, path.Path(["user", "friends"]), path.get_anonymous_field() + ) + new_field = new_root.get_descendant_or_error(new_path) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 2, 3]) + self.assertAllEqual(leaf_node.values, [0, 0, 1, 0, 0]) - def test_get_positional_index_calculate(self): - expr = create_expression.create_expression_from_prensor( - 
prensor_test_util.create_nested_prensor()) - new_root, new_path = index.get_positional_index( - expr, path.Path(["user", "friends"]), path.get_anonymous_field()) - new_field = new_root.get_descendant_or_error(new_path) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 2, 3]) - self.assertAllEqual(leaf_node.values, [0, 0, 1, 0, 0]) - - def test_get_index_from_end_calculate(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor()) - new_root, new_path = index.get_index_from_end( - expr, path.Path(["user", "friends"]), path.get_anonymous_field()) - print(f"test_get_index_from_end_calculate: new_path: {new_path}") - new_field = new_root.get_descendant_or_error(new_path) - print(f"test_get_index_from_end_calculate: new_field: {str(new_field)}") + def test_get_index_from_end_calculate(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor() + ) + new_root, new_path = index.get_index_from_end( + expr, path.Path(["user", "friends"]), path.get_anonymous_field() + ) + print(f"test_get_index_from_end_calculate: new_path: {new_path}") + new_field = new_root.get_descendant_or_error(new_path) + print(f"test_get_index_from_end_calculate: new_field: {str(new_field)}") - leaf_node = expression_test_util.calculate_value_slowly(new_field) - print(f"test_get_index_from_end_calculate: leaf_node: {str(leaf_node)}") + leaf_node = expression_test_util.calculate_value_slowly(new_field) + print(f"test_get_index_from_end_calculate: leaf_node: {str(leaf_node)}") - self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 2, 3]) - self.assertAllEqual(leaf_node.values, [-1, -2, -1, -1, -1]) + self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 2, 3]) + self.assertAllEqual(leaf_node.values, [-1, -2, -1, -1, -1]) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_impl/map_prensor.py b/struct2tensor/expression_impl/map_prensor.py index f226b7b..ac62d5f 100644 --- a/struct2tensor/expression_impl/map_prensor.py +++ b/struct2tensor/expression_impl/map_prensor.py @@ -107,261 +107,306 @@ from struct2tensor.expression_impl import project -def map_sparse_tensor(root: expression.Expression, root_path: path.Path, - paths: Sequence[path.Path], - operation: Callable[..., tf.SparseTensor], - is_repeated: bool, dtype: tf.DType, - new_field_name: path.Step) -> expression.Expression: - """Maps a sparse tensor. - - Args: - root: the root of the expression. - root_path: the path relative to which the sparse tensors are calculated. - paths: the input paths relative to the root_path - operation: a method that takes the list of sparse tensors as input and - returns a sparse tensor. - is_repeated: true if the result of operation is repeated. - dtype: dtype of the result of the operation. - new_field_name: root_path.get_child(new_field_name) is the path of the - result. - - Returns: - A new root expression containing the old root expression plus the new path, - root_path.get_child(new_field_name), with the result of the operation. - """ - return _map_sparse_tensor_impl(root, root_path, paths, operation, is_repeated, - dtype, new_field_name)[0] - - -def map_ragged_tensor(root: expression.Expression, root_path: path.Path, - paths: Sequence[path.Path], - operation: Callable[..., tf.RaggedTensor], - is_repeated: bool, dtype: tf.DType, - new_field_name: path.Step) -> expression.Expression: - """Map a ragged tensor. 
- - Args: - root: the root of the expression. - root_path: the path relative to which the ragged tensors are calculated. - paths: the input paths relative to the root_path - operation: a method that takes the list of ragged tensors as input and - returns a ragged tensor. - is_repeated: true if the result of operation is repeated. - dtype: dtype of the result of the operation. - new_field_name: root_path.get_child(new_field_name) is the path of the - result. - - Returns: - A new root expression containing the old root expression plus the new path, - root_path.get_child(new_field_name), with the result of the operation. - """ - return _map_ragged_tensor_impl(root, root_path, paths, operation, is_repeated, - dtype, new_field_name)[0] +def map_sparse_tensor( + root: expression.Expression, + root_path: path.Path, + paths: Sequence[path.Path], + operation: Callable[..., tf.SparseTensor], + is_repeated: bool, + dtype: tf.DType, + new_field_name: path.Step, +) -> expression.Expression: + """Maps a sparse tensor. + + Args: + root: the root of the expression. + root_path: the path relative to which the sparse tensors are calculated. + paths: the input paths relative to the root_path + operation: a method that takes the list of sparse tensors as input and + returns a sparse tensor. + is_repeated: true if the result of operation is repeated. + dtype: dtype of the result of the operation. + new_field_name: root_path.get_child(new_field_name) is the path of the + result. + + Returns: + A new root expression containing the old root expression plus the new path, + root_path.get_child(new_field_name), with the result of the operation. + """ + return _map_sparse_tensor_impl( + root, root_path, paths, operation, is_repeated, dtype, new_field_name + )[0] + + +def map_ragged_tensor( + root: expression.Expression, + root_path: path.Path, + paths: Sequence[path.Path], + operation: Callable[..., tf.RaggedTensor], + is_repeated: bool, + dtype: tf.DType, + new_field_name: path.Step, +) -> expression.Expression: + """Map a ragged tensor. + + Args: + root: the root of the expression. + root_path: the path relative to which the ragged tensors are calculated. + paths: the input paths relative to the root_path + operation: a method that takes the list of ragged tensors as input and + returns a ragged tensor. + is_repeated: true if the result of operation is repeated. + dtype: dtype of the result of the operation. + new_field_name: root_path.get_child(new_field_name) is the path of the + result. + + Returns: + A new root expression containing the old root expression plus the new path, + root_path.get_child(new_field_name), with the result of the operation. + """ + return _map_ragged_tensor_impl( + root, root_path, paths, operation, is_repeated, dtype, new_field_name + )[0] class _MapPrensorExpression(expression.Expression): - """Maps the values of the given expression. + """Maps the values of the given expression. + + It maps the value of a sub-tree (i.e. a Prensor) to a single prensor + LeafNodeTensor. Therefore its sources are all the (known) descendants of + `origin`: it usually should follow a project(...) to make known descendants + clear. + + _MapPrensorExpression is intended to be a child of the origin. See + map_prensor_impl for example usage. 
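A typical call that ends up building a _MapPrensorExpression (it mirrors map_prensor_test.py further down; expr is assumed to be an expression over a prensor with an optional int32 leaf "foo"):

    new_root = map_prensor.map_sparse_tensor(
        expr,
        path.Path([]),           # root_path the sparse tensors are relative to
        [path.Path(["foo"])],    # input paths
        lambda x: x * 2,         # SparseTensor -> SparseTensor
        False,                   # result is optional, not repeated
        tf.int32,
        "foo_doubled",
    )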
+ + """ + + def __init__( + self, + origin: expression.Expression, + operation: Callable[ + [prensor.Prensor, calculate_options.Options], prensor.LeafNodeTensor + ], + is_repeated: bool, + dtype: tf.DType, + ): + super().__init__( + is_repeated, dtype, validate_step_format=origin.validate_step_format + ) + self._origin = origin + self._operation = operation + + def _get_source_paths(self) -> Sequence[path.Path]: + """Returns the source paths in a deterministic order.""" + result = [k for k in self._origin.get_known_descendants()] + result.sort() + return result + + def get_source_expressions(self) -> Sequence[expression.Expression]: + subtree = self._origin.get_known_descendants() + source_paths = self._get_source_paths() + return [subtree[k] for k in source_paths] + + def calculate( + self, + sources: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> prensor.LeafNodeTensor: + source_tree = prensor.create_prensor_from_descendant_nodes( + {k: v for k, v in zip(self._get_source_paths(), sources)} + ) + return self._operation(source_tree, options) + + def calculation_is_identity(self) -> bool: + return False + + def calculation_equal(self, expr: expression.Expression) -> bool: + return self is expr + + def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]: + return None + + def known_field_names(self) -> FrozenSet[path.Step]: + return frozenset() + + +def _as_leaf_node_no_checks( + sparse_tensor: tf.SparseTensor, is_repeated: bool +) -> prensor.LeafNodeTensor: + """Take a SparseTensor and create a LeafNodeTensor, no checks.""" + if is_repeated: + parent_index = tf.transpose(sparse_tensor.indices)[0] + else: + parent_index = tf.reshape(sparse_tensor.indices, [-1]) + return prensor.LeafNodeTensor(parent_index, sparse_tensor.values, is_repeated) + + +def _as_leaf_node_with_checks( + sparse_tensor: tf.SparseTensor, is_repeated: bool, required_batch_size: tf.Tensor +) -> prensor.LeafNodeTensor: + """Take a SparseTensor and create a LeafNodeTensor, with checks.""" + assertions = [tf.assert_equal(sparse_tensor.dense_shape[0], required_batch_size)] + if is_repeated: + assertions.append(tf.assert_equal(tf.shape(sparse_tensor.indices)[1], 2)) + else: + assertions.append(tf.assert_equal(tf.shape(sparse_tensor.indices)[1], 1)) - It maps the value of a sub-tree (i.e. a Prensor) to a single prensor - LeafNodeTensor. Therefore its sources are all the (known) descendants of - `origin`: it usually should follow a project(...) to make known descendants - clear. - - _MapPrensorExpression is intended to be a child of the origin. See - map_prensor_impl for example usage. 
- - """ - - def __init__(self, origin: expression.Expression, - operation: Callable[[prensor.Prensor, calculate_options - .Options], prensor.LeafNodeTensor], - is_repeated: bool, dtype: tf.DType): - super().__init__( - is_repeated, dtype, validate_step_format=origin.validate_step_format - ) - self._origin = origin - self._operation = operation - - def _get_source_paths(self) -> Sequence[path.Path]: - """Returns the source paths in a deterministic order.""" - result = [k for k in self._origin.get_known_descendants()] - result.sort() - return result - - def get_source_expressions(self) -> Sequence[expression.Expression]: - subtree = self._origin.get_known_descendants() - source_paths = self._get_source_paths() - return [subtree[k] for k in source_paths] - - def calculate( - self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.LeafNodeTensor: - source_tree = prensor.create_prensor_from_descendant_nodes( - {k: v for k, v in zip(self._get_source_paths(), sources)}) - return self._operation(source_tree, options) - - def calculation_is_identity(self) -> bool: - return False - - def calculation_equal(self, expr: expression.Expression) -> bool: - return self is expr - - def _get_child_impl(self, - field_name: path.Step) -> Optional[expression.Expression]: - return None - - def known_field_names(self) -> FrozenSet[path.Step]: - return frozenset() - - -def _as_leaf_node_no_checks(sparse_tensor: tf.SparseTensor, - is_repeated: bool) -> prensor.LeafNodeTensor: - """Take a SparseTensor and create a LeafNodeTensor, no checks.""" - if is_repeated: - parent_index = tf.transpose(sparse_tensor.indices)[0] - else: - parent_index = tf.reshape(sparse_tensor.indices, [-1]) - return prensor.LeafNodeTensor(parent_index, sparse_tensor.values, is_repeated) - - -def _as_leaf_node_with_checks(sparse_tensor: tf.SparseTensor, is_repeated: bool, - required_batch_size: tf.Tensor - ) -> prensor.LeafNodeTensor: - """Take a SparseTensor and create a LeafNodeTensor, with checks.""" - assertions = [ - tf.assert_equal(sparse_tensor.dense_shape[0], required_batch_size) - ] - if is_repeated: - assertions.append(tf.assert_equal(tf.shape(sparse_tensor.indices)[1], 2)) - else: - assertions.append(tf.assert_equal(tf.shape(sparse_tensor.indices)[1], 1)) - - with tf.control_dependencies(assertions): - # TODO(b/72947444): Check that the resulting tensor is canonical, that the - # indices are in lexicographical order, and that the indices fit in the - # shape. Moreover, maybe we should check if it is repeated that it is a - # "ragged array". + with tf.control_dependencies(assertions): + # TODO(b/72947444): Check that the resulting tensor is canonical, that the + # indices are in lexicographical order, and that the indices fit in the + # shape. Moreover, maybe we should check if it is repeated that it is a + # "ragged array". 
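How the conversion below reads the indices, on made-up canonical inputs (illustration only, not part of the patch):

    import tensorflow as tf

    # Repeated field: indices are [parent, position] pairs, so column 0 of
    # the indices is the parent index.
    st = tf.sparse.SparseTensor(
        indices=[[0, 0], [1, 0], [1, 1]], values=[7, 8, 9], dense_shape=[3, 2]
    )
    print(tf.transpose(st.indices)[0])  # [0, 1, 1]
    # Optional field: indices have a single column, so they reshape to a
    # parent_index vector directly.
    opt = tf.sparse.SparseTensor(indices=[[0], [2]], values=[7, 9], dense_shape=[3])
    print(tf.reshape(opt.indices, [-1]))  # [0, 2]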
+ return _as_leaf_node_no_checks(sparse_tensor, is_repeated) + + +def _as_leaf_node( + sparse_tensor: tf.SparseTensor, + is_repeated: bool, + required_batch_size: tf.Tensor, + options: calculate_options.Options, +) -> prensor.LeafNodeTensor: + if options.sparse_checks: + return _as_leaf_node_with_checks( + sparse_tensor, is_repeated, required_batch_size + ) return _as_leaf_node_no_checks(sparse_tensor, is_repeated) -def _as_leaf_node(sparse_tensor: tf.SparseTensor, is_repeated: bool, - required_batch_size: tf.Tensor, - options: calculate_options.Options) -> prensor.LeafNodeTensor: - if options.sparse_checks: - return _as_leaf_node_with_checks(sparse_tensor, is_repeated, - required_batch_size) - return _as_leaf_node_no_checks(sparse_tensor, is_repeated) - - def _map_prensor_impl( - root: expression.Expression, root_path: path.Path, + root: expression.Expression, + root_path: path.Path, paths_needed: Sequence[path.Path], - operation: Callable[[prensor.Prensor, calculate_options.Options], prensor - .LeafNodeTensor], is_repeated: bool, dtype: tf.DType, - new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]: - """Map prensor implementation.""" - child_expr = root.get_descendant_or_error(root_path) - sibling_child_expr = project.project(child_expr, paths_needed) - new_field_expr = _MapPrensorExpression(sibling_child_expr, operation, - is_repeated, dtype) - new_path = root_path.get_child(new_field_name) - return expression_add.add_paths(root, {new_path: new_field_expr}), new_path - - -def _map_sparse_tensor_impl(root: expression.Expression, root_path: path.Path, - paths: Sequence[path.Path], - operation: Callable[..., tf.SparseTensor], - is_repeated: bool, dtype: tf.DType, - new_field_name: path.Step - ) -> Tuple[expression.Expression, path.Path]: - """Helper method for map_sparse_tensor.""" - - def new_op(pren: prensor.Prensor, - options: calculate_options.Options) -> prensor.LeafNodeTensor: - """Op for mapping prensor using the operation.""" - sparse_tensor_map = pren.get_sparse_tensors(options) - sparse_tensors = [sparse_tensor_map[p] for p in paths] - result_as_tensor = operation(*sparse_tensors) - result = _as_leaf_node(result_as_tensor, is_repeated, - sparse_tensors[0].dense_shape[0], options) - if result.values.dtype != dtype: - raise ValueError(f"Type unmatched: actual ({str(result.values.dtype)})!= expected ({str(dtype)})") - return result - - return _map_prensor_impl(root, root_path, paths, new_op, is_repeated, dtype, - new_field_name) - - -def _ragged_as_leaf_node(ragged_tensor: tf.RaggedTensor, is_repeated: bool, - reference_ragged_tensor: tf.RaggedTensor, - options: calculate_options.Options - ) -> prensor.LeafNodeTensor: - """Creates a ragged tensor as a leaf node.""" - assertions = [] - size_dim = tf.compat.dimension_at_index(ragged_tensor.shape, 0).value - reference_size_dim = tf.compat.dimension_at_index( - reference_ragged_tensor.shape, 0).value - if (size_dim is not None and reference_size_dim is not None): - if size_dim != reference_size_dim: - raise ValueError("Returned ragged tensor is not the right size.") - elif options.ragged_checks: - assertions.append( - tf.assert_equal(ragged_tensor.nrows(), reference_ragged_tensor.nrows())) - - if not is_repeated: - rowids = ragged_tensor.value_rowids() - if options.ragged_checks: - assertions.append(tf.compat.v1.assert_positive(rowids[1:] - rowids[:-1])) - if assertions: - with tf.control_dependencies(assertions): - parent_index = ragged_tensor.value_rowids() - return prensor.LeafNodeTensor(parent_index, 
ragged_tensor.values,
-                                  is_repeated)
-  else:
-    parent_index = ragged_tensor.value_rowids()
-    return prensor.LeafNodeTensor(parent_index, ragged_tensor.values,
-                                  is_repeated)
-
-
-def _map_ragged_tensor_impl(root: expression.Expression, root_path: path.Path,
-                            paths: Sequence[path.Path],
-                            operation: Callable[..., tf.RaggedTensor],
-                            is_repeated: bool, dtype: tf.DType,
-                            new_field_name: path.Step
-                           ) -> Tuple[expression.Expression, path.Path]:
-  """Maps a ragged tensor.
-
-  Args:
-    root: the root of the expression.
-    root_path: the path relative to which the ragged tensors are calculated.
-    paths: the input paths relative to the root_path
-    operation: a method that takes the list of ragged tensors as input and
-      returns a ragged tensor.
-    is_repeated: true if the result of operation is repeated.
-    dtype: dtype of the result of the operation.
-    new_field_name: root_path.get_child(new_field_name) is the path of the
-      result.
-
-  Returns:
-    An expression/path pair (expr,p) with a new root expression containing
-    the old root expression plus the new path,
-    root_path.get_child(new_field_name), with the result of the operation.
-  """
-
-  def new_op(tree: prensor.Prensor,
-             options: calculate_options.Options) -> prensor.LeafNodeTensor:
-    """Apply operation to tree."""
-    ragged_tensor_map = tree.get_ragged_tensors(options)
-    ragged_tensors = [ragged_tensor_map[p] for p in paths]
-    result_as_tensor = operation(*ragged_tensors)
-    result = _ragged_as_leaf_node(result_as_tensor, is_repeated,
-                                  ragged_tensors[0], options)
-    if result.values.dtype != dtype:
-      raise ValueError(f"Type unmatched: actual ({str(result.values.dtype)})!= expected ({str(dtype)})")
-    return result
-
-  return _map_prensor_impl(root, root_path, paths, new_op, is_repeated, dtype,
-                           new_field_name)
+    operation: Callable[
+        [prensor.Prensor, calculate_options.Options], prensor.LeafNodeTensor
+    ],
+    is_repeated: bool,
+    dtype: tf.DType,
+    new_field_name: path.Step,
+) -> Tuple[expression.Expression, path.Path]:
+    """Map prensor implementation."""
+    child_expr = root.get_descendant_or_error(root_path)
+    sibling_child_expr = project.project(child_expr, paths_needed)
+    new_field_expr = _MapPrensorExpression(
+        sibling_child_expr, operation, is_repeated, dtype
+    )
+    new_path = root_path.get_child(new_field_name)
+    return expression_add.add_paths(root, {new_path: new_field_expr}), new_path
+
+
+def _map_sparse_tensor_impl(
+    root: expression.Expression,
+    root_path: path.Path,
+    paths: Sequence[path.Path],
+    operation: Callable[..., tf.SparseTensor],
+    is_repeated: bool,
+    dtype: tf.DType,
+    new_field_name: path.Step,
+) -> Tuple[expression.Expression, path.Path]:
+    """Helper method for map_sparse_tensor."""
+
+    def new_op(
+        pren: prensor.Prensor, options: calculate_options.Options
+    ) -> prensor.LeafNodeTensor:
+        """Op for mapping prensor using the operation."""
+        sparse_tensor_map = pren.get_sparse_tensors(options)
+        sparse_tensors = [sparse_tensor_map[p] for p in paths]
+        result_as_tensor = operation(*sparse_tensors)
+        result = _as_leaf_node(
+            result_as_tensor, is_repeated, sparse_tensors[0].dense_shape[0], options
+        )
+        if result.values.dtype != dtype:
+            raise ValueError(
+                f"Type mismatch: actual ({str(result.values.dtype)}) != expected ({str(dtype)})"
+            )
+        return result
+
+    return _map_prensor_impl(
+        root, root_path, paths, new_op, is_repeated, dtype, new_field_name
+    )
+
+
+def _ragged_as_leaf_node(
+    ragged_tensor: tf.RaggedTensor,
+    is_repeated: bool,
+    reference_ragged_tensor: tf.RaggedTensor,
+    options: calculate_options.Options,
+) -> prensor.LeafNodeTensor:
+    """Creates a ragged tensor as a leaf node."""
+    assertions = []
+    size_dim = tf.compat.dimension_at_index(ragged_tensor.shape, 0).value
+    reference_size_dim = tf.compat.dimension_at_index(
+        reference_ragged_tensor.shape, 0
+    ).value
+    if size_dim is not None and reference_size_dim is not None:
+        if size_dim != reference_size_dim:
+            raise ValueError("Returned ragged tensor is not the right size.")
+    elif options.ragged_checks:
+        assertions.append(
+            tf.assert_equal(ragged_tensor.nrows(), reference_ragged_tensor.nrows())
+        )
+
+    if not is_repeated:
+        rowids = ragged_tensor.value_rowids()
+        if options.ragged_checks:
+            assertions.append(tf.compat.v1.assert_positive(rowids[1:] - rowids[:-1]))
+    if assertions:
+        with tf.control_dependencies(assertions):
+            parent_index = ragged_tensor.value_rowids()
+            return prensor.LeafNodeTensor(
+                parent_index, ragged_tensor.values, is_repeated
+            )
+    else:
+        parent_index = ragged_tensor.value_rowids()
+        return prensor.LeafNodeTensor(parent_index, ragged_tensor.values, is_repeated)
+
+
+def _map_ragged_tensor_impl(
+    root: expression.Expression,
+    root_path: path.Path,
+    paths: Sequence[path.Path],
+    operation: Callable[..., tf.RaggedTensor],
+    is_repeated: bool,
+    dtype: tf.DType,
+    new_field_name: path.Step,
+) -> Tuple[expression.Expression, path.Path]:
+    """Maps a ragged tensor.
+
+    Args:
+      root: the root of the expression.
+      root_path: the path relative to which the ragged tensors are calculated.
+      paths: the input paths relative to the root_path
+      operation: a method that takes the list of ragged tensors as input and
+        returns a ragged tensor.
+      is_repeated: true if the result of operation is repeated.
+      dtype: dtype of the result of the operation.
+      new_field_name: root_path.get_child(new_field_name) is the path of the
+        result.
+
+    Returns:
+      An expression/path pair (expr,p) with a new root expression containing
+      the old root expression plus the new path,
+      root_path.get_child(new_field_name), with the result of the operation.
+    """

+    def new_op(
+        tree: prensor.Prensor, options: calculate_options.Options
+    ) -> prensor.LeafNodeTensor:
+        """Apply operation to tree."""
+        ragged_tensor_map = tree.get_ragged_tensors(options)
+        ragged_tensors = [ragged_tensor_map[p] for p in paths]
+        result_as_tensor = operation(*ragged_tensors)
+        result = _ragged_as_leaf_node(
+            result_as_tensor, is_repeated, ragged_tensors[0], options
+        )
+        if result.values.dtype != dtype:
+            raise ValueError(
+                f"Type mismatch: actual ({str(result.values.dtype)}) != expected ({str(dtype)})"
+            )
+        return result
+
+    return _map_prensor_impl(
+        root, root_path, paths, new_op, is_repeated, dtype, new_field_name
+    )
diff --git a/struct2tensor/expression_impl/map_prensor_test.py b/struct2tensor/expression_impl/map_prensor_test.py
index 1b173b4..20d9296 100644
--- a/struct2tensor/expression_impl/map_prensor_test.py
+++ b/struct2tensor/expression_impl/map_prensor_test.py
@@ -16,7 +16,7 @@ import tensorflow as tf
 from absl.testing import absltest
 from tensorflow.python.framework import (
-    test_util,  # pylint: disable=g-direct-tensorflow-import
+    test_util,  # pylint: disable=g-direct-tensorflow-import
 )
 from struct2tensor import calculate_options, create_expression, path, prensor
@@ -25,149 +25,199 @@


 def _create_one_value_prensor():
-    """Creates a prensor expression representing a list of flat protocol buffers.
- - Returns: - a RootPrensor representing: - {} - {foo:8} - {} - """ - return prensor.create_prensor_from_descendant_nodes({ - path.Path([]): prensor_test_util.create_root_node(3), - path.Path(["foo"]): prensor_test_util.create_optional_leaf_node([1], [8]) - }) + """Creates a prensor expression representing a list of flat protocol buffers. + + Returns: + a RootPrensor representing: + {} + {foo:8} + {} + """ + return prensor.create_prensor_from_descendant_nodes( + { + path.Path([]): prensor_test_util.create_root_node(3), + path.Path(["foo"]): prensor_test_util.create_optional_leaf_node([1], [8]), + } + ) options_to_test = [ calculate_options.get_default_options(), - calculate_options.get_options_with_minimal_checks() + calculate_options.get_options_with_minimal_checks(), ] @test_util.run_all_in_graph_and_eager_modes class MapPrensorTest(tf.test.TestCase): - - def _test_assert_raises(self, test_runner): - with self.assertRaises(tf.errors.InvalidArgumentError): - test_runner(calculate_options.get_default_options()) - test_runner(calculate_options.get_options_with_minimal_checks()) - - def test_map_sparse_tensor(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_simple_prensor()) - - new_root = map_prensor.map_sparse_tensor(expr, path.Path([]), - [path.Path(["foo"])], - lambda x: x * 2, False, tf.int32, - "foo_doubled") - - leaf_node = expression_test_util.calculate_value_slowly( - new_root.get_descendant_or_error(path.Path(["foo_doubled"]))) - self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) - self.assertAllEqual(leaf_node.values, [18, 16, 14]) - - def test_map_sparse_tensor_one_output(self): - for options in options_to_test: - expr = create_expression.create_expression_from_prensor( - _create_one_value_prensor()) - - new_root = map_prensor.map_sparse_tensor(expr, path.Path( - []), [path.Path(["foo"])], lambda x: x * 2, False, tf.int32, - "foo_doubled") - - leaf_node = expression_test_util.calculate_value_slowly( - new_root.get_descendant_or_error(path.Path(["foo_doubled"])), - options=options) - self.assertAllEqual(leaf_node.parent_index, [1]) - self.assertAllEqual(leaf_node.values, [16]) - - def test_map_sparse_tensor_is_repeated(self): - for options in options_to_test: - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_simple_prensor()) - - new_root = map_prensor.map_sparse_tensor(expr, path.Path([]), - [path.Path(["foorepeated"])], - lambda x: x * 2, True, tf.int32, - "foorepeated_doubled") - - leaf_node = expression_test_util.calculate_value_slowly( - new_root.get_descendant_or_error(path.Path(["foorepeated_doubled"])), - options=options) - self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 2]) - self.assertAllEqual(leaf_node.values, [18, 16, 14, 12]) - - def test_map_sparse_tensor_is_repeated_assert(self): - - def _test_runner(options): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_simple_prensor()) - new_root = map_prensor.map_sparse_tensor(expr, path.Path([]), - [path.Path(["foo"])], - lambda x: x * 2, True, tf.int32, - "foo_doubled") - leaf_node = expression_test_util.calculate_value_slowly( - new_root.get_descendant_or_error(path.Path(["foo_doubled"])), - options=options) - self.evaluate(leaf_node.parent_index) - self.evaluate(leaf_node.values) - - self._test_assert_raises(_test_runner) - - def test_map_sparse_tensor_assert(self): - def _test_runner(options): - expr = create_expression.create_expression_from_prensor( - 
prensor_test_util.create_simple_prensor()) - - new_root = map_prensor.map_sparse_tensor(expr, path.Path([]), - [path.Path(["foorepeated"])], - lambda x: x * 2, False, tf.int32, - "foorepeated_doubled") - - leaf_node = expression_test_util.calculate_value_slowly( - new_root.get_descendant_or_error(path.Path(["foorepeated_doubled"])), - options=options) - self.evaluate(leaf_node.parent_index) - self.evaluate(leaf_node.values) - - self._test_assert_raises(_test_runner) - - def test_map_sparse_tensor_assert_batch_size(self): - def _test_runner(options): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_simple_prensor()) - - new_root = map_prensor.map_sparse_tensor( - expr, path.Path([]), [path.Path(["foo"])], - lambda x: tf.sparse.concat(0, [x, x]), False, tf.int32, "foo_concat") - - leaf_node = expression_test_util.calculate_value_slowly( - new_root.get_descendant_or_error(path.Path(["foo_concat"])), - options=options) - self.evaluate(leaf_node.parent_index) - self.evaluate(leaf_node.values) - - self._test_assert_raises(_test_runner) - - - def test_skip_eager_map_ragged_tensor_repeated(self): - # This fails in eager, with an inconsistency in the ragged tensor. - if tf.executing_eagerly(): - return - for options in options_to_test: - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_simple_prensor()) - new_root = map_prensor.map_ragged_tensor(expr, path.Path([]), - [path.Path(["foorepeated"])], - lambda x: x * 2, False, tf.int32, - "foorepeated_doubled") - leaf_node = expression_test_util.calculate_value_slowly( - new_root.get_descendant_or_error(path.Path(["foorepeated_doubled"])), - options=options) - self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 2]) - self.assertAllEqual(leaf_node.values, [18, 16, 14, 12]) + def _test_assert_raises(self, test_runner): + with self.assertRaises(tf.errors.InvalidArgumentError): + test_runner(calculate_options.get_default_options()) + test_runner(calculate_options.get_options_with_minimal_checks()) + + def test_map_sparse_tensor(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_simple_prensor() + ) + + new_root = map_prensor.map_sparse_tensor( + expr, + path.Path([]), + [path.Path(["foo"])], + lambda x: x * 2, + False, + tf.int32, + "foo_doubled", + ) + + leaf_node = expression_test_util.calculate_value_slowly( + new_root.get_descendant_or_error(path.Path(["foo_doubled"])) + ) + self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) + self.assertAllEqual(leaf_node.values, [18, 16, 14]) + + def test_map_sparse_tensor_one_output(self): + for options in options_to_test: + expr = create_expression.create_expression_from_prensor( + _create_one_value_prensor() + ) + + new_root = map_prensor.map_sparse_tensor( + expr, + path.Path([]), + [path.Path(["foo"])], + lambda x: x * 2, + False, + tf.int32, + "foo_doubled", + ) + + leaf_node = expression_test_util.calculate_value_slowly( + new_root.get_descendant_or_error(path.Path(["foo_doubled"])), + options=options, + ) + self.assertAllEqual(leaf_node.parent_index, [1]) + self.assertAllEqual(leaf_node.values, [16]) + + def test_map_sparse_tensor_is_repeated(self): + for options in options_to_test: + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_simple_prensor() + ) + + new_root = map_prensor.map_sparse_tensor( + expr, + path.Path([]), + [path.Path(["foorepeated"])], + lambda x: x * 2, + True, + tf.int32, + "foorepeated_doubled", + ) + + leaf_node = 
expression_test_util.calculate_value_slowly( + new_root.get_descendant_or_error(path.Path(["foorepeated_doubled"])), + options=options, + ) + self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 2]) + self.assertAllEqual(leaf_node.values, [18, 16, 14, 12]) + + def test_map_sparse_tensor_is_repeated_assert(self): + def _test_runner(options): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_simple_prensor() + ) + new_root = map_prensor.map_sparse_tensor( + expr, + path.Path([]), + [path.Path(["foo"])], + lambda x: x * 2, + True, + tf.int32, + "foo_doubled", + ) + leaf_node = expression_test_util.calculate_value_slowly( + new_root.get_descendant_or_error(path.Path(["foo_doubled"])), + options=options, + ) + self.evaluate(leaf_node.parent_index) + self.evaluate(leaf_node.values) + + self._test_assert_raises(_test_runner) + + def test_map_sparse_tensor_assert(self): + def _test_runner(options): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_simple_prensor() + ) + + new_root = map_prensor.map_sparse_tensor( + expr, + path.Path([]), + [path.Path(["foorepeated"])], + lambda x: x * 2, + False, + tf.int32, + "foorepeated_doubled", + ) + + leaf_node = expression_test_util.calculate_value_slowly( + new_root.get_descendant_or_error(path.Path(["foorepeated_doubled"])), + options=options, + ) + self.evaluate(leaf_node.parent_index) + self.evaluate(leaf_node.values) + + self._test_assert_raises(_test_runner) + + def test_map_sparse_tensor_assert_batch_size(self): + def _test_runner(options): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_simple_prensor() + ) + + new_root = map_prensor.map_sparse_tensor( + expr, + path.Path([]), + [path.Path(["foo"])], + lambda x: tf.sparse.concat(0, [x, x]), + False, + tf.int32, + "foo_concat", + ) + + leaf_node = expression_test_util.calculate_value_slowly( + new_root.get_descendant_or_error(path.Path(["foo_concat"])), + options=options, + ) + self.evaluate(leaf_node.parent_index) + self.evaluate(leaf_node.values) + + self._test_assert_raises(_test_runner) + + def test_skip_eager_map_ragged_tensor_repeated(self): + # This fails in eager, with an inconsistency in the ragged tensor. + if tf.executing_eagerly(): + return + for options in options_to_test: + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_simple_prensor() + ) + new_root = map_prensor.map_ragged_tensor( + expr, + path.Path([]), + [path.Path(["foorepeated"])], + lambda x: x * 2, + False, + tf.int32, + "foorepeated_doubled", + ) + leaf_node = expression_test_util.calculate_value_slowly( + new_root.get_descendant_or_error(path.Path(["foorepeated_doubled"])), + options=options, + ) + self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 2]) + self.assertAllEqual(leaf_node.values, [18, 16, 14, 12]) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_impl/map_prensor_to_prensor.py b/struct2tensor/expression_impl/map_prensor_to_prensor.py index 6f1310e..78ad6da 100644 --- a/struct2tensor/expression_impl/map_prensor_to_prensor.py +++ b/struct2tensor/expression_impl/map_prensor_to_prensor.py @@ -78,128 +78,135 @@ class Schema: - """A finite schema for a prensor. + """A finite schema for a prensor. - Effectively, this stores everything for the prensor but the tensors - themselves. + Effectively, this stores everything for the prensor but the tensors + themselves. 
- Notice that this is slightly different than schema_pb2.Schema, although - similar in nature. At present, there is no clear way to extract is_repeated - and dtype from schema_pb2.Schema. + Notice that this is slightly different than schema_pb2.Schema, although + similar in nature. At present, there is no clear way to extract is_repeated + and dtype from schema_pb2.Schema. - See create_schema below for constructing a schema. + See create_schema below for constructing a schema. - Note that for LeafNodeTensor, dtype is not None. - Also, for ChildNodeTensor and RootNodeTensor, dtype is None. However, - a ChildNodeTensor or RootNodeTensor could be childless. - """ - - def __init__(self, - is_repeated: bool = True, - dtype: Optional[tf.DType] = None, - schema_feature: Optional[schema_pb2.Feature] = None, - children: Optional[Dict[path.Step, "Schema"]] = None): - """Create a new Schema object. - - Args: - is_repeated: is the root repeated? - dtype: tf.dtype of the root if the root is a leaf, otherwise None. - schema_feature: schema_pb2.Feature of the root (no struct_domain - necessary) - children: child schemas. + Note that for LeafNodeTensor, dtype is not None. + Also, for ChildNodeTensor and RootNodeTensor, dtype is None. However, + a ChildNodeTensor or RootNodeTensor could be childless. """ - self._is_repeated = is_repeated - self._type = dtype - self._schema_feature = schema_feature - self._children = children if children is not None else {} - # Cannot have a type and children. - assert (self._type is None or not self._children) - - @property - def is_repeated(self) -> bool: - return self._is_repeated - - @property - def type(self) -> Optional[tf.DType]: - return self._type - @property - def schema_feature(self) -> Optional[schema_pb2.Feature]: - return self._schema_feature - - def get_child(self, key: path.Step): - return self._children[key] - - def known_field_names(self) -> FrozenSet[path.Step]: - return frozenset(self._children.keys()) - - def __str__(self) -> str: - return (f"Schema(is_repeated:{self._is_repeated} type:{self._type}" + def __init__( + self, + is_repeated: bool = True, + dtype: Optional[tf.DType] = None, + schema_feature: Optional[schema_pb2.Feature] = None, + children: Optional[Dict[path.Step, "Schema"]] = None, + ): + """Create a new Schema object. + + Args: + is_repeated: is the root repeated? + dtype: tf.dtype of the root if the root is a leaf, otherwise None. + schema_feature: schema_pb2.Feature of the root (no struct_domain + necessary) + children: child schemas. + """ + self._is_repeated = is_repeated + self._type = dtype + self._schema_feature = schema_feature + self._children = children if children is not None else {} + # Cannot have a type and children. 
+        assert self._type is None or not self._children
+
+    @property
+    def is_repeated(self) -> bool:
+        return self._is_repeated
+
+    @property
+    def type(self) -> Optional[tf.DType]:
+        return self._type
+
+    @property
+    def schema_feature(self) -> Optional[schema_pb2.Feature]:
+        return self._schema_feature
+
+    def get_child(self, key: path.Step):
+        return self._children[key]
+
+    def known_field_names(self) -> FrozenSet[path.Step]:
+        return frozenset(self._children.keys())
+
+    def __str__(self) -> str:
+        return (
+            f"Schema(is_repeated:{self._is_repeated} type:{self._type}"
             " "
             f"schema_feature:({self._schema_feature})"
             " "
-            f"children:{self._children})")
+            f"children:{self._children})"
+        )
 
 
-def create_schema(is_repeated: bool = True,
-                  dtype: Optional[tf.DType] = None,
-                  schema_feature: Optional[schema_pb2.Feature] = None,
-                  children: Optional[Dict[path.Step, Any]] = None) -> Schema:
-  """Create a schema recursively.
+def create_schema(
+    is_repeated: bool = True,
+    dtype: Optional[tf.DType] = None,
+    schema_feature: Optional[schema_pb2.Feature] = None,
+    children: Optional[Dict[path.Step, Any]] = None,
+) -> Schema:
+    """Create a schema recursively.
 
-  Example:
-    my_result_schema = create_schema(
-      is_repeated=True,
-      children={"foo2":{is_repeated=True, dtype=tf.int64},
-                "bar2":{is_repeated=False, dtype=tf.int64}})
+    Example:
+      my_result_schema = create_schema(
+          is_repeated=True,
+          children={"foo2": {"is_repeated": True, "dtype": tf.int64},
+                    "bar2": {"is_repeated": False, "dtype": tf.int64}})
 
-  Args:
-    is_repeated: whether the root is repeated.
-    dtype: the dtype of a leaf (None for non-leaves).
-    schema_feature: the schema_pb2.Feature describing this expression. name and
-      struct_domain need not be specified.
-    children: the child schemas. Note that the value type of children is either
-      a Schema or a dictionary of arguments to create_schema.
+    Args:
+      is_repeated: whether the root is repeated.
+      dtype: the dtype of a leaf (None for non-leaves).
+      schema_feature: the schema_pb2.Feature describing this expression. name and
+        struct_domain need not be specified.
+      children: the child schemas. Note that the value type of children is either
+        a Schema or a dictionary of arguments to create_schema.
 
-  Returns:
-    a new Schema represented by the inputs.
-  """
-  children_dict = children or {}
-  child_schemas = {
-      k: _create_schema_helper(v) for k, v in children_dict.items()
-  }
-  return Schema(
-      is_repeated=is_repeated,
-      dtype=dtype,
-      schema_feature=schema_feature,
-      children=child_schemas)
+    Returns:
+      a new Schema represented by the inputs.
+    """
+    children_dict = children or {}
+    child_schemas = {k: _create_schema_helper(v) for k, v in children_dict.items()}
+    return Schema(
+        is_repeated=is_repeated,
+        dtype=dtype,
+        schema_feature=schema_feature,
+        children=child_schemas,
+    )
 
 
 def _create_schema_helper(my_input) -> Schema:
-  """Helper for create_schema.
+    """Helper for create_schema.
 
-  If my_input is a Schema, it is just returned.
+    If my_input is a Schema, it is just returned.
 
-  Otherwise, my_input should be a dictionary of arguments to create_schema.
+    Otherwise, my_input should be a dictionary of arguments to create_schema.
 
-  Args:
-    my_input: either a Schema or a dictionary which optionally has the keys
-      is_repeated, dtype, and children.
+    Args:
+      my_input: either a Schema or a dictionary which optionally has the keys
+        is_repeated, dtype, and children.
 
-  Returns:
-    a Schema.
+    Returns:
+      a Schema.
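+
+    Example (an illustrative sketch, not part of the original docstring; both
+    calls below return equivalent Schemas, since the dict form is expanded via
+    create_schema(**my_input)):
+      _create_schema_helper(Schema(is_repeated=True, dtype=tf.int64))
+      _create_schema_helper({"is_repeated": True, "dtype": tf.int64})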
+ """ + if isinstance(my_input, Schema): + return my_input + return create_schema(**my_input) def map_prensor_to_prensor( - root_expr: expression.Expression, source: path.Path, + root_expr: expression.Expression, + source: path.Path, paths_needed: Sequence[path.Path], prensor_op: Callable[[prensor.Prensor], prensor.Prensor], - output_schema: Schema) -> expression.Expression: - r"""Maps an expression to a prensor, and merges that prensor. + output_schema: Schema, +) -> expression.Expression: + r"""Maps an expression to a prensor, and merges that prensor. For example, suppose you have an op my_op, that takes a prensor of the form: @@ -245,232 +252,237 @@ def map_prensor_to_prensor( Returns: A new expression where the prensor is merged. """ - original_child = root_expr.get_descendant_or_error(source).project( - paths_needed) - prensor_child = _PrensorOpExpression(original_child, prensor_op, - output_schema) - paths_map = { - source.get_child(k): prensor_child.get_child_or_error(k) - for k in prensor_child.known_field_names() - } - return expression_add.add_paths(root_expr, paths_map) + original_child = root_expr.get_descendant_or_error(source).project(paths_needed) + prensor_child = _PrensorOpExpression(original_child, prensor_op, output_schema) + paths_map = { + source.get_child(k): prensor_child.get_child_or_error(k) + for k in prensor_child.known_field_names() + } + return expression_add.add_paths(root_expr, paths_map) ##################### Implementation Follows ################################### def _tree_as_node(prensor_tree: prensor.Prensor) -> "_TreeAsNode": - """Create a _TreeAsNode, a NodeTensor with a prensor property. + """Create a _TreeAsNode, a NodeTensor with a prensor property. - The root node of the tree is pulled out, the resulting NodeTensor has the same - properties. + The root node of the tree is pulled out, the resulting NodeTensor has the same + properties. - Args: - prensor_tree: the original tree, that will be accessible via prensor. + Args: + prensor_tree: the original tree, that will be accessible via prensor. - Returns: - a _TreeAsNode x with x.prensor==prensor_tree, and x is equivalent - to prensor_tree.node. - """ - top_node = prensor_tree.node - if isinstance(top_node, prensor.RootNodeTensor): - return _PrensorAsRootNodeTensor(prensor_tree, top_node) - if isinstance(top_node, prensor.ChildNodeTensor): - return _PrensorAsChildNodeTensor(prensor_tree, top_node) - return _PrensorAsLeafNodeTensor(prensor_tree, top_node) + Returns: + a _TreeAsNode x with x.prensor==prensor_tree, and x is equivalent + to prensor_tree.node. 
+ """ + top_node = prensor_tree.node + if isinstance(top_node, prensor.RootNodeTensor): + return _PrensorAsRootNodeTensor(prensor_tree, top_node) + if isinstance(top_node, prensor.ChildNodeTensor): + return _PrensorAsChildNodeTensor(prensor_tree, top_node) + return _PrensorAsLeafNodeTensor(prensor_tree, top_node) class _PrensorAsRootNodeTensor(prensor.RootNodeTensor): - """A root node tensor that has a prensor property.""" + """A root node tensor that has a prensor property.""" - def __init__(self, prensor_tree: prensor.Prensor, - root: prensor.RootNodeTensor): - """Call _tree_as_node instead.""" - super().__init__(root.size) - self._prensor = prensor_tree + def __init__(self, prensor_tree: prensor.Prensor, root: prensor.RootNodeTensor): + """Call _tree_as_node instead.""" + super().__init__(root.size) + self._prensor = prensor_tree - @property - def prensor(self): - return self._prensor + @property + def prensor(self): + return self._prensor class _PrensorAsChildNodeTensor(prensor.ChildNodeTensor): - """A child node tensor that has a prensor property.""" + """A child node tensor that has a prensor property.""" - def __init__(self, prensor_tree: prensor.Prensor, - child: prensor.ChildNodeTensor): - """Call _tree_as_node instead.""" - super().__init__(child.parent_index, child.is_repeated) - self._prensor = prensor_tree + def __init__(self, prensor_tree: prensor.Prensor, child: prensor.ChildNodeTensor): + """Call _tree_as_node instead.""" + super().__init__(child.parent_index, child.is_repeated) + self._prensor = prensor_tree - @property - def prensor(self): - return self._prensor + @property + def prensor(self): + return self._prensor class _PrensorAsLeafNodeTensor(prensor.LeafNodeTensor): - """A leaf node tensor that has a prensor property.""" + """A leaf node tensor that has a prensor property.""" - def __init__(self, prensor_tree: prensor.Prensor, - leaf: prensor.LeafNodeTensor): - """Call _tree_as_node instead.""" - super().__init__(leaf.parent_index, leaf.values, leaf.is_repeated) - self._prensor = prensor_tree + def __init__(self, prensor_tree: prensor.Prensor, leaf: prensor.LeafNodeTensor): + """Call _tree_as_node instead.""" + super().__init__(leaf.parent_index, leaf.values, leaf.is_repeated) + self._prensor = prensor_tree - @property - def prensor(self): - return self._prensor + @property + def prensor(self): + return self._prensor -_TreeAsNode = Union[_PrensorAsLeafNodeTensor, _PrensorAsChildNodeTensor, - _PrensorAsRootNodeTensor] +_TreeAsNode = Union[ + _PrensorAsLeafNodeTensor, _PrensorAsChildNodeTensor, _PrensorAsRootNodeTensor +] -def _get_schema_or_error(parent: expression.Expression, - step: path.Step) -> Schema: - if not isinstance(parent, (_PrensorOpExpression, _PrensorOpChildExpression)): - raise ValueError("No parent schema") - parent_schema = parent.schema - return parent_schema.get_child(step) +def _get_schema_or_error(parent: expression.Expression, step: path.Step) -> Schema: + if not isinstance(parent, (_PrensorOpExpression, _PrensorOpChildExpression)): + raise ValueError("No parent schema") + parent_schema = parent.schema + return parent_schema.get_child(step) class _PrensorOpChildExpression(expression.Expression): - """A helper class for PrensorOpExpression, representing its descendants.""" - - def __init__(self, parent: expression.Expression, step: path.Step, - schema: Schema): - super().__init__( - schema.is_repeated, - schema.type, - schema_feature=schema.schema_feature, - validate_step_format=parent.validate_step_format, - ) - self._parent = parent - 
self._step = step - self._schema = schema - - @property - def schema(self): - return self._schema - - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [self._parent] - - def calculate(self, - source_tensors: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> _TreeAsNode: - [parent_result] = source_tensors - if not isinstance(parent_result, - (_PrensorAsLeafNodeTensor, _PrensorAsChildNodeTensor, - _PrensorAsRootNodeTensor)): - raise ValueError("Parent did not return _TreeAsNode") - my_prensor = parent_result.prensor.get_child(self._step) - return _tree_as_node(my_prensor) - - def calculation_is_identity(self) -> bool: - return False - - def calculation_equal(self, expr: expression.Expression) -> bool: - if isinstance(expr, _PrensorOpChildExpression): - return self._step == expr._step # pylint: disable=protected-access - return False - - def _get_child_impl(self, - field_name: path.Step) -> Optional[expression.Expression]: - """Implementation of getting a named child in a subclass. - - This is called and cached inside get_child(). - - Args: - field_name: the field accessed. - - Returns: - The expression of the field, or None if the field does not exist. - """ - if field_name not in self._schema.known_field_names(): - return None - child_schema = self._schema.get_child(field_name) - return _PrensorOpChildExpression(self, field_name, child_schema) - - def known_field_names(self) -> FrozenSet[path.Step]: - return self._schema.known_field_names() + """A helper class for PrensorOpExpression, representing its descendants.""" + + def __init__(self, parent: expression.Expression, step: path.Step, schema: Schema): + super().__init__( + schema.is_repeated, + schema.type, + schema_feature=schema.schema_feature, + validate_step_format=parent.validate_step_format, + ) + self._parent = parent + self._step = step + self._schema = schema + + @property + def schema(self): + return self._schema + + def get_source_expressions(self) -> Sequence[expression.Expression]: + return [self._parent] + + def calculate( + self, + source_tensors: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> _TreeAsNode: + [parent_result] = source_tensors + if not isinstance( + parent_result, + ( + _PrensorAsLeafNodeTensor, + _PrensorAsChildNodeTensor, + _PrensorAsRootNodeTensor, + ), + ): + raise ValueError("Parent did not return _TreeAsNode") + my_prensor = parent_result.prensor.get_child(self._step) + return _tree_as_node(my_prensor) + + def calculation_is_identity(self) -> bool: + return False + + def calculation_equal(self, expr: expression.Expression) -> bool: + if isinstance(expr, _PrensorOpChildExpression): + return self._step == expr._step # pylint: disable=protected-access + return False + + def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]: + """Implementation of getting a named child in a subclass. + + This is called and cached inside get_child(). + + Args: + field_name: the field accessed. + + Returns: + The expression of the field, or None if the field does not exist. 
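+
+        Example (an illustrative sketch; `expr` is assumed to be a
+        _PrensorOpChildExpression whose schema declares a child "bar2"):
+          expr.get_child("bar2")     # -> a _PrensorOpChildExpression
+          expr.get_child("missing")  # -> None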
+ """ + if field_name not in self._schema.known_field_names(): + return None + child_schema = self._schema.get_child(field_name) + return _PrensorOpChildExpression(self, field_name, child_schema) + + def known_field_names(self) -> FrozenSet[path.Step]: + return self._schema.known_field_names() class _PrensorOpExpression(expression.Expression): - """An expression generated by a callable that returns a prensor. - - Note that it is expected that project() is called on the origin before - this method, as is generally advised before calling get_known_descendants(). - """ - - def __init__(self, origin: expression.Expression, - operation: Callable[[prensor.Prensor], prensor.Prensor], - schema: Schema): - """Creates a new expression. - - Conceptually, this expression = operation applied to origin, where the - result is tagged with schema. + """An expression generated by a callable that returns a prensor. - Args: - origin: the expression which is used to generate the input prensor. - operation: the operation applied to the origin. - schema: the schema of the result. If a path is not in the schema, it is - not calculated. + Note that it is expected that project() is called on the origin before + this method, as is generally advised before calling get_known_descendants(). """ - super().__init__( - schema.is_repeated, - schema.type, - schema_feature=schema.schema_feature, - validate_step_format=origin.validate_step_format, - ) - self._origin = origin - self._operation = operation - self._schema = schema - - @property - def schema(self): - return self._schema - - def _get_source_paths(self) -> Sequence[path.Path]: - """Returns the source paths in a deterministic order.""" - result = [k for k in self._origin.get_known_descendants()] - # In order to make certain that the source_paths are in a deterministic - # order, we sort them here. - result.sort() - return result - - def get_source_expressions(self) -> Sequence[expression.Expression]: - subtree = self._origin.get_known_descendants() - source_paths = self._get_source_paths() - return [subtree[k] for k in source_paths] - - def calculate(self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> _TreeAsNode: - source_tree = prensor.create_prensor_from_descendant_nodes( - {k: v for k, v in zip(self._get_source_paths(), sources)}) - result_tree = self._operation(source_tree) - # TODO(martinz): consider a full type check on result_tree. This - # can be done outside of the GraphDef. - return _tree_as_node(result_tree) - - def calculation_is_identity(self) -> bool: - return False - - def calculation_equal(self, expr: expression.Expression) -> bool: - return self is expr - - def _get_child_impl(self, - field_name: path.Step) -> Optional[expression.Expression]: - if field_name not in self._schema.known_field_names(): - return None - child_schema = self._schema.get_child(field_name) - return _PrensorOpChildExpression(self, field_name, child_schema) - - def known_field_names(self) -> FrozenSet[path.Step]: - return self._schema.known_field_names() + def __init__( + self, + origin: expression.Expression, + operation: Callable[[prensor.Prensor], prensor.Prensor], + schema: Schema, + ): + """Creates a new expression. + + Conceptually, this expression = operation applied to origin, where the + result is tagged with schema. + + Args: + origin: the expression which is used to generate the input prensor. 
+ operation: the operation applied to the origin. + schema: the schema of the result. If a path is not in the schema, it is + not calculated. + """ + super().__init__( + schema.is_repeated, + schema.type, + schema_feature=schema.schema_feature, + validate_step_format=origin.validate_step_format, + ) + + self._origin = origin + self._operation = operation + self._schema = schema + + @property + def schema(self): + return self._schema + + def _get_source_paths(self) -> Sequence[path.Path]: + """Returns the source paths in a deterministic order.""" + result = [k for k in self._origin.get_known_descendants()] + # In order to make certain that the source_paths are in a deterministic + # order, we sort them here. + result.sort() + return result + + def get_source_expressions(self) -> Sequence[expression.Expression]: + subtree = self._origin.get_known_descendants() + source_paths = self._get_source_paths() + return [subtree[k] for k in source_paths] + + def calculate( + self, + sources: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> _TreeAsNode: + source_tree = prensor.create_prensor_from_descendant_nodes( + {k: v for k, v in zip(self._get_source_paths(), sources)} + ) + result_tree = self._operation(source_tree) + # TODO(martinz): consider a full type check on result_tree. This + # can be done outside of the GraphDef. + return _tree_as_node(result_tree) + + def calculation_is_identity(self) -> bool: + return False + + def calculation_equal(self, expr: expression.Expression) -> bool: + return self is expr + + def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]: + if field_name not in self._schema.known_field_names(): + return None + child_schema = self._schema.get_child(field_name) + return _PrensorOpChildExpression(self, field_name, child_schema) + + def known_field_names(self) -> FrozenSet[path.Step]: + return self._schema.known_field_names() diff --git a/struct2tensor/expression_impl/map_prensor_to_prensor_test.py b/struct2tensor/expression_impl/map_prensor_to_prensor_test.py index b7fd2c7..43272d3 100644 --- a/struct2tensor/expression_impl/map_prensor_to_prensor_test.py +++ b/struct2tensor/expression_impl/map_prensor_to_prensor_test.py @@ -16,16 +16,16 @@ import tensorflow as tf from absl.testing import absltest from tensorflow.python.framework import ( - test_util, # pylint: disable=g-direct-tensorflow-import + test_util, # pylint: disable=g-direct-tensorflow-import ) from tensorflow_metadata.proto.v0 import schema_pb2 # For tf.Session.Run against a Prensor from struct2tensor import ( - calculate, - create_expression, - path, - prensor, # pylint: disable=unused-import + calculate, + create_expression, + path, + prensor, # pylint: disable=unused-import ) from struct2tensor.expression_impl import map_prensor_to_prensor from struct2tensor.test import prensor_test_util @@ -33,123 +33,126 @@ @test_util.run_all_in_graph_and_eager_modes class MapPrensorToPrensorTest(tf.test.TestCase): - - def test_map_prensor_to_prensor(self): - original = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor()) - - def my_prensor_op(original_prensor): - # Note that we are copying over the original root prensor node. The root - # node is ignored in the result. 
- return prensor.create_prensor_from_descendant_nodes({ - path.Path([]): - original_prensor.node, - path.Path(["bar2"]): - original_prensor.get_child_or_error("bar").node, - path.Path(["keep_me2"]): - original_prensor.get_child_or_error("keep_me").node - }) - - # Since the top node is actually a child node, we use the child schema. - my_output_schema = map_prensor_to_prensor.create_schema( - is_repeated=True, - children={ - "bar2": { - "is_repeated": True, - "dtype": tf.string + def test_map_prensor_to_prensor(self): + original = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor() + ) + + def my_prensor_op(original_prensor): + # Note that we are copying over the original root prensor node. The root + # node is ignored in the result. + return prensor.create_prensor_from_descendant_nodes( + { + path.Path([]): original_prensor.node, + path.Path(["bar2"]): original_prensor.get_child_or_error( + "bar" + ).node, + path.Path(["keep_me2"]): original_prensor.get_child_or_error( + "keep_me" + ).node, + } + ) + + # Since the top node is actually a child node, we use the child schema. + my_output_schema = map_prensor_to_prensor.create_schema( + is_repeated=True, + children={ + "bar2": {"is_repeated": True, "dtype": tf.string}, + "keep_me2": {"is_repeated": False, "dtype": tf.bool}, }, - "keep_me2": { - "is_repeated": False, - "dtype": tf.bool - } - }) - - result = map_prensor_to_prensor.map_prensor_to_prensor( - root_expr=original, - source=path.Path(["doc"]), - paths_needed=[path.Path(["bar"]), - path.Path(["keep_me"])], - prensor_op=my_prensor_op, - output_schema=my_output_schema) - - doc_result = result.get_child_or_error("doc") - bar_result = doc_result.get_child_or_error("bar") - keep_me_result = doc_result.get_child_or_error("keep_me") - bar2_result = doc_result.get_child_or_error("bar2") - keep_me2_result = doc_result.get_child_or_error("keep_me2") - self.assertIsNone(doc_result.get_child("missing_field")) - self.assertTrue(bar_result.is_repeated) - self.assertTrue(bar2_result.is_repeated) - self.assertEqual(bar_result.type, tf.string) - self.assertEqual(bar2_result.type, tf.string) - self.assertFalse(keep_me_result.is_repeated) - self.assertFalse(keep_me2_result.is_repeated) - self.assertEqual(keep_me_result.type, tf.bool) - self.assertEqual(keep_me2_result.type, tf.bool) - - [prensor_result] = calculate.calculate_prensors([result]) - - doc_value = prensor_result.get_child_or_error("doc") - self.assertAllEqual([0, 1, 1], doc_value.node.parent_index) - bar2_value = doc_value.get_child_or_error("bar2") - self.assertAllEqual([0, 1, 1, 2], bar2_value.node.parent_index) - self.assertAllEqual([b"a", b"b", b"c", b"d"], bar2_value.node.values) - keep_me2_value = doc_value.get_child_or_error("keep_me2") - self.assertAllEqual([0, 1], keep_me2_value.node.parent_index) - self.assertAllEqual([False, True], keep_me2_value.node.values) - - def test_map_prensor_to_prensor_with_schema(self): - original = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor()) - - def my_prensor_op(original_prensor): - # Note that we are copying over the original root prensor node. The root - # node is ignored in the result. 
- return prensor.create_prensor_from_descendant_nodes({ - path.Path([]): - original_prensor.node, - path.Path(["bar2"]): - original_prensor.get_child_or_error("bar").node, - path.Path(["keep_me2"]): - original_prensor.get_child_or_error("keep_me").node - }) - - bar2_feature = schema_pb2.Feature() - bar2_feature.value_count.max = 7 - keep_me2_feature = schema_pb2.Feature() - keep_me2_feature.value_count.max = 10 - - # Since the top node is actually a child node, we use the child schema. - my_output_schema = map_prensor_to_prensor.create_schema( - is_repeated=True, - children={ - "bar2": { - "is_repeated": True, - "dtype": tf.string, - "schema_feature": bar2_feature + ) + + result = map_prensor_to_prensor.map_prensor_to_prensor( + root_expr=original, + source=path.Path(["doc"]), + paths_needed=[path.Path(["bar"]), path.Path(["keep_me"])], + prensor_op=my_prensor_op, + output_schema=my_output_schema, + ) + + doc_result = result.get_child_or_error("doc") + bar_result = doc_result.get_child_or_error("bar") + keep_me_result = doc_result.get_child_or_error("keep_me") + bar2_result = doc_result.get_child_or_error("bar2") + keep_me2_result = doc_result.get_child_or_error("keep_me2") + self.assertIsNone(doc_result.get_child("missing_field")) + self.assertTrue(bar_result.is_repeated) + self.assertTrue(bar2_result.is_repeated) + self.assertEqual(bar_result.type, tf.string) + self.assertEqual(bar2_result.type, tf.string) + self.assertFalse(keep_me_result.is_repeated) + self.assertFalse(keep_me2_result.is_repeated) + self.assertEqual(keep_me_result.type, tf.bool) + self.assertEqual(keep_me2_result.type, tf.bool) + + [prensor_result] = calculate.calculate_prensors([result]) + + doc_value = prensor_result.get_child_or_error("doc") + self.assertAllEqual([0, 1, 1], doc_value.node.parent_index) + bar2_value = doc_value.get_child_or_error("bar2") + self.assertAllEqual([0, 1, 1, 2], bar2_value.node.parent_index) + self.assertAllEqual([b"a", b"b", b"c", b"d"], bar2_value.node.values) + keep_me2_value = doc_value.get_child_or_error("keep_me2") + self.assertAllEqual([0, 1], keep_me2_value.node.parent_index) + self.assertAllEqual([False, True], keep_me2_value.node.values) + + def test_map_prensor_to_prensor_with_schema(self): + original = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor() + ) + + def my_prensor_op(original_prensor): + # Note that we are copying over the original root prensor node. The root + # node is ignored in the result. + return prensor.create_prensor_from_descendant_nodes( + { + path.Path([]): original_prensor.node, + path.Path(["bar2"]): original_prensor.get_child_or_error( + "bar" + ).node, + path.Path(["keep_me2"]): original_prensor.get_child_or_error( + "keep_me" + ).node, + } + ) + + bar2_feature = schema_pb2.Feature() + bar2_feature.value_count.max = 7 + keep_me2_feature = schema_pb2.Feature() + keep_me2_feature.value_count.max = 10 + + # Since the top node is actually a child node, we use the child schema. 
+ my_output_schema = map_prensor_to_prensor.create_schema( + is_repeated=True, + children={ + "bar2": { + "is_repeated": True, + "dtype": tf.string, + "schema_feature": bar2_feature, + }, + "keep_me2": { + "is_repeated": False, + "dtype": tf.bool, + "schema_feature": keep_me2_feature, + }, }, - "keep_me2": { - "is_repeated": False, - "dtype": tf.bool, - "schema_feature": keep_me2_feature - } - }) + ) - result = map_prensor_to_prensor.map_prensor_to_prensor( - root_expr=original, - source=path.Path(["doc"]), - paths_needed=[path.Path(["bar"]), - path.Path(["keep_me"])], - prensor_op=my_prensor_op, - output_schema=my_output_schema) + result = map_prensor_to_prensor.map_prensor_to_prensor( + root_expr=original, + source=path.Path(["doc"]), + paths_needed=[path.Path(["bar"]), path.Path(["keep_me"])], + prensor_op=my_prensor_op, + output_schema=my_output_schema, + ) - doc_result = result.get_child_or_error("doc") - bar2_result = doc_result.get_child_or_error("bar2") - self.assertEqual(bar2_result.schema_feature.value_count.max, 7) + doc_result = result.get_child_or_error("doc") + bar2_result = doc_result.get_child_or_error("bar2") + self.assertEqual(bar2_result.schema_feature.value_count.max, 7) - keep_me2_result = doc_result.get_child_or_error("keep_me2") - self.assertEqual(keep_me2_result.schema_feature.value_count.max, 10) + keep_me2_result = doc_result.get_child_or_error("keep_me2") + self.assertEqual(keep_me2_result.schema_feature.value_count.max, 10) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_impl/map_values.py b/struct2tensor/expression_impl/map_values.py index a781ffb..d6f2368 100644 --- a/struct2tensor/expression_impl/map_values.py +++ b/struct2tensor/expression_impl/map_values.py @@ -29,132 +29,162 @@ def map_many_values( - root: expression.Expression, parent_path: path.Path, - source_fields: Sequence[path.Step], operation: Callable[..., tf.Tensor], + root: expression.Expression, + parent_path: path.Path, + source_fields: Sequence[path.Step], + operation: Callable[..., tf.Tensor], dtype: tf.DType, - new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]: - """Map multiple sibling fields into a new sibling. - - All source fields must have the same shape, and the shape of the output - must be the same as well. - - Args: - root: original root. - parent_path: parent path of all sources and the new field. - source_fields: source fields of the operation. Must have the same shape. - operation: operation from source_fields to new field. - dtype: type of new field. - new_field_name: name of the new field. - - Returns: - The new expression and the new path as a pair. - """ - new_path = parent_path.get_child(new_field_name) - return expression_add.add_paths( - root, { - new_path: - _MapValuesExpression([ - root.get_descendant_or_error(parent_path.get_child(f)) - for f in source_fields - ], operation, dtype) - }), new_path + new_field_name: path.Step, +) -> Tuple[expression.Expression, path.Path]: + """Map multiple sibling fields into a new sibling. + + All source fields must have the same shape, and the shape of the output + must be the same as well. + + Args: + root: original root. + parent_path: parent path of all sources and the new field. + source_fields: source fields of the operation. Must have the same shape. + operation: operation from source_fields to new field. + dtype: type of new field. + new_field_name: name of the new field. + + Returns: + The new expression and the new path as a pair. 
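+
+    Example (an illustrative sketch; `root` is assumed to have optional int64
+    leaves "foo" and "bar" directly under the root, and "foo_plus_bar" is a
+    hypothetical output name):
+      new_root, new_path = map_many_values(
+          root, path.Path([]), ["foo", "bar"],
+          lambda x, y: x + y, tf.int64, "foo_plus_bar")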
+ """ + new_path = parent_path.get_child(new_field_name) + return expression_add.add_paths( + root, + { + new_path: _MapValuesExpression( + [ + root.get_descendant_or_error(parent_path.get_child(f)) + for f in source_fields + ], + operation, + dtype, + ) + }, + ), new_path def map_values_anonymous( - root: expression.Expression, source_path: path.Path, + root: expression.Expression, + source_path: path.Path, operation: Callable[[tf.Tensor], tf.Tensor], - dtype: tf.DType) -> Tuple[expression.Expression, path.Path]: - """Map field into a new sibling. - - The shape of the output must be the same as the input. - - Args: - root: original root. - source_path: source of the operation. - operation: operation from source_fields to new field. - dtype: type of new field. - - Returns: - The new expression and the new path as a pair. - """ - if not source_path: - raise ValueError('Cannot map the root.') - return map_many_values(root, source_path.get_parent(), - [source_path.field_list[-1]], operation, dtype, - path.get_anonymous_field()) - - -def map_values(root: expression.Expression, source_path: path.Path, - operation: Callable[[tf.Tensor], tf.Tensor], dtype: tf.DType, - new_field_name: path.Step) -> expression.Expression: - """Map field into a new sibling. - - The shape of the output must be the same as the input. - - Args: - root: original root. - source_path: source of the operation. - operation: operation from source_fields to new field. - dtype: type of new field. - new_field_name: name of the new field. - - Returns: - The new expression. - """ - if not source_path: - raise ValueError('Cannot map the root.') - return map_many_values(root, source_path.get_parent(), - [source_path.field_list[-1]], operation, dtype, - new_field_name)[0] + dtype: tf.DType, +) -> Tuple[expression.Expression, path.Path]: + """Map field into a new sibling. + + The shape of the output must be the same as the input. + + Args: + root: original root. + source_path: source of the operation. + operation: operation from source_fields to new field. + dtype: type of new field. + + Returns: + The new expression and the new path as a pair. + """ + if not source_path: + raise ValueError("Cannot map the root.") + return map_many_values( + root, + source_path.get_parent(), + [source_path.field_list[-1]], + operation, + dtype, + path.get_anonymous_field(), + ) + + +def map_values( + root: expression.Expression, + source_path: path.Path, + operation: Callable[[tf.Tensor], tf.Tensor], + dtype: tf.DType, + new_field_name: path.Step, +) -> expression.Expression: + """Map field into a new sibling. + + The shape of the output must be the same as the input. + + Args: + root: original root. + source_path: source of the operation. + operation: operation from source_fields to new field. + dtype: type of new field. + new_field_name: name of the new field. + + Returns: + The new expression. + """ + if not source_path: + raise ValueError("Cannot map the root.") + return map_many_values( + root, + source_path.get_parent(), + [source_path.field_list[-1]], + operation, + dtype, + new_field_name, + )[0] def _leaf_node_or_error(node: prensor.NodeTensor) -> prensor.LeafNodeTensor: - if isinstance(node, prensor.LeafNodeTensor): - return node - raise ValueError(f'node is {str(type(node))} not LeafNodeTensor') + if isinstance(node, prensor.LeafNodeTensor): + return node + raise ValueError(f"node is {str(type(node))} not LeafNodeTensor") class _MapValuesExpression(expression.Expression): - """Map the values of the given expression. 
- - _MapValuesExpression is intended to be a sibling of origin. - The operation should return a tensor that is the same size as its input. - """ - - def __init__(self, origin: Sequence[expression.Expression], - operation: Callable[..., tf.Tensor], dtype: tf.DType): - super().__init__(origin[0].is_repeated, dtype) - assert all([self.is_repeated == x.is_repeated for x in origin]) - self._origin = origin - self._operation = operation - - def get_source_expressions(self) -> Sequence[expression.Expression]: - return self._origin - - def calculate( - self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: - source_leaves = [_leaf_node_or_error(s) for s in sources] - source_values = [s.values for s in source_leaves] - # TODO(martinz): Check that: - # source_values have equal parent_index. - # output_value has the same size as the input. - return prensor.LeafNodeTensor(source_leaves[0].parent_index, - self._operation(*source_values), - self._is_repeated) - - def calculation_is_identity(self) -> bool: - return False - - def calculation_equal(self, expr: expression.Expression) -> bool: - return self is expr - - def _get_child_impl(self, - field_name: path.Step) -> Optional[expression.Expression]: - return None - - def known_field_names(self) -> FrozenSet[path.Step]: - return frozenset() + """Map the values of the given expression. + + _MapValuesExpression is intended to be a sibling of origin. + The operation should return a tensor that is the same size as its input. + """ + + def __init__( + self, + origin: Sequence[expression.Expression], + operation: Callable[..., tf.Tensor], + dtype: tf.DType, + ): + super().__init__(origin[0].is_repeated, dtype) + assert all([self.is_repeated == x.is_repeated for x in origin]) + self._origin = origin + self._operation = operation + + def get_source_expressions(self) -> Sequence[expression.Expression]: + return self._origin + + def calculate( + self, + sources: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> prensor.NodeTensor: + source_leaves = [_leaf_node_or_error(s) for s in sources] + source_values = [s.values for s in source_leaves] + # TODO(martinz): Check that: + # source_values have equal parent_index. + # output_value has the same size as the input. 
+ return prensor.LeafNodeTensor( + source_leaves[0].parent_index, + self._operation(*source_values), + self._is_repeated, + ) + + def calculation_is_identity(self) -> bool: + return False + + def calculation_equal(self, expr: expression.Expression) -> bool: + return self is expr + + def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]: + return None + + def known_field_names(self) -> FrozenSet[path.Step]: + return frozenset() diff --git a/struct2tensor/expression_impl/map_values_test.py b/struct2tensor/expression_impl/map_values_test.py index 69afd71..95d5c3b 100644 --- a/struct2tensor/expression_impl/map_values_test.py +++ b/struct2tensor/expression_impl/map_values_test.py @@ -16,7 +16,7 @@ import tensorflow as tf from absl.testing import absltest from tensorflow.python.framework import ( - test_util, # pylint: disable=g-direct-tensorflow-import + test_util, # pylint: disable=g-direct-tensorflow-import ) from struct2tensor import create_expression, path, prensor @@ -26,53 +26,66 @@ @test_util.run_all_in_graph_and_eager_modes class MapValuesTest(tf.test.TestCase): + def test_map_values_anonymous(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_simple_prensor() + ) - def test_map_values_anonymous(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_simple_prensor()) + new_root, p = map_values.map_values_anonymous( + expr, path.Path(["foo"]), lambda x: x * 2, tf.int64 + ) - new_root, p = map_values.map_values_anonymous(expr, path.Path(["foo"]), - lambda x: x * 2, tf.int64) + leaf_node = expression_test_util.calculate_value_slowly( + new_root.get_descendant_or_error(p) + ) + self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) + self.assertAllEqual(leaf_node.values, [18, 16, 14]) - leaf_node = expression_test_util.calculate_value_slowly( - new_root.get_descendant_or_error(p)) - self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) - self.assertAllEqual(leaf_node.values, [18, 16, 14]) + def test_map_values(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_simple_prensor() + ) - def test_map_values(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_simple_prensor()) + new_root = map_values.map_values( + expr, path.Path(["foo"]), lambda x: x * 2, tf.int64, "foo_doubled" + ) - new_root = map_values.map_values(expr, path.Path(["foo"]), lambda x: x * 2, - tf.int64, "foo_doubled") + leaf_node = expression_test_util.calculate_value_slowly( + new_root.get_descendant_or_error(path.Path(["foo_doubled"])) + ) + self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) + self.assertAllEqual(leaf_node.values, [18, 16, 14]) - leaf_node = expression_test_util.calculate_value_slowly( - new_root.get_descendant_or_error(path.Path(["foo_doubled"]))) - self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) - self.assertAllEqual(leaf_node.values, [18, 16, 14]) + def test_map_many_values(self): + expr = create_expression.create_expression_from_prensor( + prensor.create_prensor_from_descendant_nodes( + { + path.Path([]): prensor_test_util.create_root_node(3), + path.Path(["foo"]): prensor_test_util.create_optional_leaf_node( + [0, 2, 3], [9, 8, 7] + ), + path.Path(["bar"]): prensor_test_util.create_optional_leaf_node( + [0, 2, 3], [10, 20, 30] + ), + } + ) + ) - def test_map_many_values(self): - expr = create_expression.create_expression_from_prensor( - prensor.create_prensor_from_descendant_nodes({ - path.Path([]): - 
prensor_test_util.create_root_node(3), - path.Path(["foo"]): - prensor_test_util.create_optional_leaf_node([0, 2, 3], - [9, 8, 7]), - path.Path(["bar"]): - prensor_test_util.create_optional_leaf_node([0, 2, 3], - [10, 20, 30]) - })) + new_root, p = map_values.map_many_values( + expr, + path.Path([]), + ["foo", "bar"], + lambda x, y: x + y, + tf.int64, + "new_field", + ) - new_root, p = map_values.map_many_values(expr, path.Path([]), - ["foo", "bar"], lambda x, y: x + y, - tf.int64, "new_field") - - leaf_node = expression_test_util.calculate_value_slowly( - new_root.get_descendant_or_error(p)) - self.assertAllEqual(leaf_node.parent_index, [0, 2, 3]) - self.assertAllEqual(leaf_node.values, [19, 28, 37]) + leaf_node = expression_test_util.calculate_value_slowly( + new_root.get_descendant_or_error(p) + ) + self.assertAllEqual(leaf_node.parent_index, [0, 2, 3]) + self.assertAllEqual(leaf_node.values, [19, 28, 37]) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_impl/parquet.py b/struct2tensor/expression_impl/parquet.py index 256387c..d3ccc53 100644 --- a/struct2tensor/expression_impl/parquet.py +++ b/struct2tensor/expression_impl/parquet.py @@ -41,27 +41,30 @@ def create_expression_from_parquet_file( - filenames: List[str]) -> placeholder._PlaceholderRootExpression: # pylint: disable=protected-access - """Creates a placeholder expression from a parquet file. + filenames: List[str], +) -> placeholder._PlaceholderRootExpression: # pylint: disable=protected-access + """Creates a placeholder expression from a parquet file. - Args: - filenames: A list of parquet files. + Args: + filenames: A list of parquet files. - Returns: - A PlaceholderRootExpression that should be used as the root of an expression - graph. - """ - metadata = pq.ParquetFile(filenames[0]).metadata - parquet_schema = metadata.schema - arrow_schema = parquet_schema.to_arrow_schema() + Returns: + A PlaceholderRootExpression that should be used as the root of an expression + graph. + """ + metadata = pq.ParquetFile(filenames[0]).metadata + parquet_schema = metadata.schema + arrow_schema = parquet_schema.to_arrow_schema() - root_schema = mpp.create_schema( - is_repeated=True, - children=_create_children_from_arrow_fields( - [arrow_schema.field_by_name(name) for name in arrow_schema.names])) + root_schema = mpp.create_schema( + is_repeated=True, + children=_create_children_from_arrow_fields( + [arrow_schema.field_by_name(name) for name in arrow_schema.names] + ), + ) - # pylint: disable=protected-access - return placeholder._PlaceholderRootExpression(root_schema) + # pylint: disable=protected-access + return placeholder._PlaceholderRootExpression(root_schema) def calculate_parquet_values( @@ -69,466 +72,491 @@ def calculate_parquet_values( root_exp: placeholder._PlaceholderRootExpression, # pylint: disable=protected-access filenames: List[str], batch_size: int, - options: Optional[calculate_options.Options] = None): - """Calculates expressions and returns a parquet dataset. - - Args: - expressions: A list of expressions to calculate. - root_exp: The root placeholder expression to use as the feed dict. - filenames: A list of parquet files. - batch_size: The number of messages to batch. - options: calculate options. - - Returns: - A parquet dataset. 
- """ - pqds = _ParquetDatasetWithExpression(expressions, root_exp, filenames, - batch_size, options) - return pqds.map(pqds._calculate_prensor) # pylint: disable=protected-access - - -class _RawParquetDataset(tf.compat.v1.data.Dataset): - """A dataset which reads columns from parquet and outputs a vector of tensors. - - A ParquetDataset is a Dataset of batches of messages (records). - Every leaf field field of the messages in each batch has its own values tensor - and parent indices tensors (which encodes the structural information). - - The user has control over which parent indices of which fields in a path to - read, and is determined by parent_index_paths and path_index. - - View //struct2tensor/ops/parquet_dataset_op.cc - for a better understanding of what format the vector of tensors is in. - """ - - def __init__(self, filenames: List[str], value_paths: List[str], - value_dtypes: List[tf.DType], parent_index_paths: List[str], - path_index: List[int], batch_size: int): - """Creates a ParquetDataset. - - Args: - filenames: A list containing the name(s) of the file(s) to be read. - value_paths: A list of strings of the dotstring path(s) of each leaf - path(s). - value_dtypes: value_dtypes[i] is the Tensorflow data type value_paths[i] - would be of. - parent_index_paths: A list of strings of the dotstring path(s) of the - path(s) to be read. - path_index: A list containing the index of each field to get the parent - index of. This will have the same length as parent_index_paths. - batch_size: An int that determines how many messages are parsed into one - prensor tree in an iteration. If there are fewer than batch_size - remaining messages, then all remaining messages will be returned. - - Raises: - ValueError: if the column does not exist in the parquet schema. - ValueError: if the column dtype does not match the value_dtype passed in. - """ - self._filenames = filenames - self._value_paths = value_paths - self._value_dtypes = tuple(value_dtypes) - self._parent_index_paths = parent_index_paths - self._path_index = path_index - self._batch_size = batch_size - - super().__init__() - - def _get_column_path_to_index_mapping(self, metadata_file) -> Dict[str, int]: - """Gets the column index of every column. + options: Optional[calculate_options.Options] = None, +): + """Calculates expressions and returns a parquet dataset. Args: - metadata_file: the file to be used as the metadata. If there is no - metadata_file, any file from file_names will suffice. + expressions: A list of expressions to calculate. + root_exp: The root placeholder expression to use as the feed dict. + filenames: A list of parquet files. + batch_size: The number of messages to batch. + options: calculate options. Returns: - A dictionary mapping path name (str) to column index (int). + A parquet dataset. """ - metadata = pq.ParquetFile(metadata_file).metadata + pqds = _ParquetDatasetWithExpression( + expressions, root_exp, filenames, batch_size, options + ) + return pqds.map(pqds._calculate_prensor) # pylint: disable=protected-access - return { - metadata.schema.column(index).path: index - for index in range(metadata.num_columns) - } +class _RawParquetDataset(tf.compat.v1.data.Dataset): + """A dataset which reads columns from parquet and outputs a vector of tensors. - def _parquet_to_tf_type(self, parquet_type: str) -> Union[tf.DType, None]: - """Maps tensorflow datatype to a parquet datatype. - - Args: - parquet_type: a string representing the parquet datatype. 
- - Returns: - the tensorflow datatype equivalent of a parquet datatype. - """ - return { - "BOOLEAN": tf.bool, - "INT32": tf.int32, - "INT64": tf.int64, - "FLOAT": tf.float32, - "DOUBLE": tf.double, - "BYTE_ARRAY": tf.string - }.get(parquet_type) - - def _as_variant_tensor(self): - return gen_parquet_dataset.parquet_dataset( - self._filenames, - value_paths=self._value_paths, - value_dtypes=self._value_dtypes, - parent_index_paths=self._parent_index_paths, - path_index=self._path_index, - batch_size=self._batch_size) - - def _inputs(self): - return [] - - @property - def output_types(self): - res = [] - column_counter = 0 - prev = self._parent_index_paths[0] - res.append(tf.int64) - for i in range(1, len(self._parent_index_paths)): - curr = self._parent_index_paths[i] - res.append(tf.int64) - if curr != prev: - res.append(self._value_dtypes[column_counter]) - column_counter += 1 - prev = curr - res.append(tf.int64) - res.append(self._value_dtypes[column_counter]) - self.output_dtypes = tuple(res) - return self.output_dtypes - - @property - def output_shapes(self): - return (tf.TensorShape([]),) + tuple( - tf.TensorShape([None]) for i in range(1, len(self.output_dtypes))) - - @property - def output_classes(self): - return tuple(tf.Tensor for i in range(len(self.output_dtypes))) - - -class ParquetDataset(_RawParquetDataset): - """A dataset which reads columns from a parquet file and returns a prensor. - - The prensor will have a PrensorTypeSpec, which is created based on - value_paths. - - Note: In tensorflow v1 this dataset will not return a prensor. The output will - be the same format as _RawParquetDataset's output (a vector of tensors). - The following is a workaround in v1: - pq_ds = ParquetDataset(...) - type_spec = pq_ds.element_spec - tensors = pq_ds.make_one_shot_iterator().get_next() - prensor = type_spec.from_components(tensors) - session.run(prensor) - """ + A ParquetDataset is a Dataset of batches of messages (records). + Every leaf field field of the messages in each batch has its own values tensor + and parent indices tensors (which encodes the structural information). - def __init__(self, filenames: List[str], value_paths: List[str], - batch_size: int): - """Creates a ParquetDataset. + The user has control over which parent indices of which fields in a path to + read, and is determined by parent_index_paths and path_index. - Args: - filenames: A list containing the name(s) of the file(s) to be read. - value_paths: A list of strings of the dotstring path(s) of each leaf - path(s). - batch_size: An int that determines how many messages are parsed into one - prensor tree in an iteration. If there are fewer than batch_size - remaining messages, then all remaining messages will be returned. - - Raises: - ValueError: if the column does not exist in the parquet schema. + View //struct2tensor/ops/parquet_dataset_op.cc + for a better understanding of what format the vector of tensors is in. 
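+
+    For example (an editorial illustration, mirroring the Dremel test data in
+    parquet_test.py): a batch of two messages {DocId: 10} and {DocId: 20}
+    yields, for the "DocId" leaf, a values tensor [10, 20] and a parent index
+    tensor [0, 1], meaning values[i] belongs to message parent_index[i].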
""" - self._filenames = filenames - self._value_paths = value_paths - self._batch_size = batch_size - - for filename in filenames: - self._validate_file(filename, value_paths) - - self._value_dtypes = self._get_column_dtypes(filenames[0], value_paths) - - self._parent_index_paths = [] - self._path_index = [] - - self.element_structure = self._create_prensor_spec() - self._create_parent_index_paths_and_index_from_type_spec( - self.element_structure, 0, 0) - - super().__init__(filenames, self._value_paths, self._value_dtypes, - self._parent_index_paths, self._path_index, batch_size) - def _get_column_dtypes( - self, metadata_file: str, - value_paths: List[str]) -> List[Union[tf.DType, None]]: - """Returns a list of tensorflow datatypes for each column. + def __init__( + self, + filenames: List[str], + value_paths: List[str], + value_dtypes: List[tf.DType], + parent_index_paths: List[str], + path_index: List[int], + batch_size: int, + ): + """Creates a ParquetDataset. + + Args: + filenames: A list containing the name(s) of the file(s) to be read. + value_paths: A list of strings of the dotstring path(s) of each leaf + path(s). + value_dtypes: value_dtypes[i] is the Tensorflow data type value_paths[i] + would be of. + parent_index_paths: A list of strings of the dotstring path(s) of the + path(s) to be read. + path_index: A list containing the index of each field to get the parent + index of. This will have the same length as parent_index_paths. + batch_size: An int that determines how many messages are parsed into one + prensor tree in an iteration. If there are fewer than batch_size + remaining messages, then all remaining messages will be returned. + + Raises: + ValueError: if the column does not exist in the parquet schema. + ValueError: if the column dtype does not match the value_dtype passed in. + """ + self._filenames = filenames + self._value_paths = value_paths + self._value_dtypes = tuple(value_dtypes) + self._parent_index_paths = parent_index_paths + self._path_index = path_index + self._batch_size = batch_size + + super().__init__() + + def _get_column_path_to_index_mapping(self, metadata_file) -> Dict[str, int]: + """Gets the column index of every column. + + Args: + metadata_file: the file to be used as the metadata. If there is no + metadata_file, any file from file_names will suffice. + + Returns: + A dictionary mapping path name (str) to column index (int). + """ + metadata = pq.ParquetFile(metadata_file).metadata + + return { + metadata.schema.column(index).path: index + for index in range(metadata.num_columns) + } - Args: - metadata_file: the file to be used as the metadata. If there is no - metadata_file, any file from file_names will suffice. - value_paths: A list of strings of the dotstring path(s). + def _parquet_to_tf_type(self, parquet_type: str) -> Union[tf.DType, None]: + """Maps tensorflow datatype to a parquet datatype. + + Args: + parquet_type: a string representing the parquet datatype. + + Returns: + the tensorflow datatype equivalent of a parquet datatype. 
+ """ + return { + "BOOLEAN": tf.bool, + "INT32": tf.int32, + "INT64": tf.int64, + "FLOAT": tf.float32, + "DOUBLE": tf.double, + "BYTE_ARRAY": tf.string, + }.get(parquet_type) + + def _as_variant_tensor(self): + return gen_parquet_dataset.parquet_dataset( + self._filenames, + value_paths=self._value_paths, + value_dtypes=self._value_dtypes, + parent_index_paths=self._parent_index_paths, + path_index=self._path_index, + batch_size=self._batch_size, + ) + + def _inputs(self): + return [] + + @property + def output_types(self): + res = [] + column_counter = 0 + prev = self._parent_index_paths[0] + res.append(tf.int64) + for i in range(1, len(self._parent_index_paths)): + curr = self._parent_index_paths[i] + res.append(tf.int64) + if curr != prev: + res.append(self._value_dtypes[column_counter]) + column_counter += 1 + prev = curr + res.append(tf.int64) + res.append(self._value_dtypes[column_counter]) + self.output_dtypes = tuple(res) + return self.output_dtypes - Returns: - A list of tensorflow datatypes for each column. This list aligns with - value_paths. - """ - path_to_column_index = self._get_column_path_to_index_mapping(metadata_file) - metadata = pq.ParquetFile(metadata_file).metadata + @property + def output_shapes(self): + return (tf.TensorShape([]),) + tuple( + tf.TensorShape([None]) for i in range(1, len(self.output_dtypes)) + ) - value_dtypes = [] - for column in value_paths: - col = metadata.schema.column(path_to_column_index[column]) - parquet_type = col.physical_type - value_dtypes.append(self._parquet_to_tf_type(parquet_type)) - return value_dtypes + @property + def output_classes(self): + return tuple(tf.Tensor for i in range(len(self.output_dtypes))) - def _validate_file(self, filename: str, value_paths: List[str]): - """Checks if each requested path exists in the parquet file. - Args: - filename: The parquet filename. - value_paths: A list of strings of the dotstring path(s). - - Raises: - ValueError: if a path does not exist in the parquet file's schema. +class ParquetDataset(_RawParquetDataset): + """A dataset which reads columns from a parquet file and returns a prensor. + + The prensor will have a PrensorTypeSpec, which is created based on + value_paths. + + Note: In tensorflow v1 this dataset will not return a prensor. The output will + be the same format as _RawParquetDataset's output (a vector of tensors). + The following is a workaround in v1: + pq_ds = ParquetDataset(...) + type_spec = pq_ds.element_spec + tensors = pq_ds.make_one_shot_iterator().get_next() + prensor = type_spec.from_components(tensors) + session.run(prensor) """ - metadata = pq.ParquetFile(filename).metadata - paths = {} - for i in range(metadata.num_columns): - col = metadata.schema.column(i) - p = (col.path) - paths[p] = col.physical_type + def __init__(self, filenames: List[str], value_paths: List[str], batch_size: int): + """Creates a ParquetDataset. + + Args: + filenames: A list containing the name(s) of the file(s) to be read. + value_paths: A list of strings of the dotstring path(s) of each leaf + path(s). + batch_size: An int that determines how many messages are parsed into one + prensor tree in an iteration. If there are fewer than batch_size + remaining messages, then all remaining messages will be returned. + + Raises: + ValueError: if the column does not exist in the parquet schema. 
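+
+        Example (an editorial sketch, eager mode; the filename is hypothetical):
+
+          pq_ds = ParquetDataset(
+              filenames=["/tmp/f.parquet"],
+              value_paths=["DocId", "Name.Language.Code"],
+              batch_size=2,
+          )
+          for pren in pq_ds:  # each element is a prensor of up to 2 messages
+            ...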
+ """ + self._filenames = filenames + self._value_paths = value_paths + self._batch_size = batch_size + + for filename in filenames: + self._validate_file(filename, value_paths) + + self._value_dtypes = self._get_column_dtypes(filenames[0], value_paths) + + self._parent_index_paths = [] + self._path_index = [] + + self.element_structure = self._create_prensor_spec() + self._create_parent_index_paths_and_index_from_type_spec( + self.element_structure, 0, 0 + ) + + super().__init__( + filenames, + self._value_paths, + self._value_dtypes, + self._parent_index_paths, + self._path_index, + batch_size, + ) + + def _get_column_dtypes( + self, metadata_file: str, value_paths: List[str] + ) -> List[Union[tf.DType, None]]: + """Returns a list of tensorflow datatypes for each column. + + Args: + metadata_file: the file to be used as the metadata. If there is no + metadata_file, any file from file_names will suffice. + value_paths: A list of strings of the dotstring path(s). + + Returns: + A list of tensorflow datatypes for each column. This list aligns with + value_paths. + """ + path_to_column_index = self._get_column_path_to_index_mapping(metadata_file) + metadata = pq.ParquetFile(metadata_file).metadata + + value_dtypes = [] + for column in value_paths: + col = metadata.schema.column(path_to_column_index[column]) + parquet_type = col.physical_type + value_dtypes.append(self._parquet_to_tf_type(parquet_type)) + return value_dtypes + + def _validate_file(self, filename: str, value_paths: List[str]): + """Checks if each requested path exists in the parquet file. + + Args: + filename: The parquet filename. + value_paths: A list of strings of the dotstring path(s). + + Raises: + ValueError: if a path does not exist in the parquet file's schema. + """ + metadata = pq.ParquetFile(filename).metadata + + paths = {} + for i in range(metadata.num_columns): + col = metadata.schema.column(i) + p = col.path + paths[p] = col.physical_type + + for p in value_paths: + if p not in paths: + raise ValueError("path " + p + " does not exist in the file.") + + def _create_children_spec( + self, field: pa.lib.Field, index_and_paths: List[Tuple[int, List[path.Step]]] + ) -> Tuple[path.Step, prensor._PrensorTypeSpec]: + """Creates the _PrensorTypeSpec for children and leaves. + + Args: + field: a pyarrow field. + index_and_paths: a list of tuple(index, list[step]), where index is the + column index this step belongs to, and list[step] are children steps of + the passed in step arg. The reason index is needed is because we need to + keep track of which column this step belongs to, to populate + parent_index_paths and path_index. + + Returns: + a child or leaf _PrensorTypeSpec. + """ + # pylint: disable=protected-access + curr_steps_as_set = collections.OrderedDict() + # Construct the dictionary of paths we need. 
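+        # Editorial illustration of the grouping below: given index_and_paths =
+        # [(0, ["Language", "Code"]), (1, ["Language", "Country"])], both entries
+        # share the first step, so the dict becomes
+        # {"Language": [(0, ["Code"]), (1, ["Country"])]}.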
+        if len(index_and_paths) >= 1 and len(index_and_paths[0][1]) >= 1:
+            for p in index_and_paths:
+                index = p[0]
+                p = p[1]
+                curr_step = p[0]
+                if p:
+                    if curr_step in curr_steps_as_set:
+                        curr_steps_as_set[curr_step].append((index, p[1:]))
+                    else:
+                        curr_steps_as_set[curr_step] = [(index, p[1:])]
+
+        field_type = field.type
+        if isinstance(field_type, pa.lib.ListType):
+            field_type = field_type.value_type
+            is_repeated = True
+        else:
+            is_repeated = False
+        if isinstance(field_type, pa.lib.StructType):
+            node_type = prensor._PrensorTypeSpec._NodeType.CHILD
+            dtype = tf.int64
+            children = [
+                self._create_children_spec(field_type[step], curr_steps_as_set[step])
+                for step in curr_steps_as_set
+            ]
+        else:
+            node_type = prensor._PrensorTypeSpec._NodeType.LEAF
+            dtype = tf.dtypes.as_dtype(field_type)
+            children = []
+
+        return (
+            field.name,
+            prensor._PrensorTypeSpec(is_repeated, node_type, dtype, children),
+        )
+
+    def _create_prensor_spec(self) -> prensor._PrensorTypeSpec:  # pylint: disable=protected-access
+        """Creates the prensor type spec based on value_paths.
+
+        Returns:
+          a root _PrensorTypeSpec.
+        """
+        metadata = pq.ParquetFile(self._filenames[0]).metadata
+        parquet_schema = metadata.schema
+        arrow_schema = parquet_schema.to_arrow_schema()
+
+        # pylint: disable=protected-access
+        # Sort the paths by number of fields.
+        paths = [path.create_path(p) for p in self._value_paths]
+        mapped = zip(paths, self._value_paths, self._value_dtypes)
+        sorted_mapped = sorted(mapped, key=lambda x: len(x[0].field_list))
+        paths, self._value_paths, self._value_dtypes = zip(*sorted_mapped)
+
+        # Creates an ordered dictionary mapping step to a list of children fields.
+        # This will allow us to find paths that share a parent.
+        curr_steps_as_set = collections.OrderedDict()
+        for i, p in enumerate(paths):
+            step = p.field_list[0]
+            if step in curr_steps_as_set:
+                curr_steps_as_set[step].append((i, p.field_list[1:]))
+            else:
+                curr_steps_as_set[step] = [(i, p.field_list[1:])]
+
+        return prensor._PrensorTypeSpec(
+            None,
+            prensor._PrensorTypeSpec._NodeType.ROOT,
+            tf.int64,
+            [
+                self._create_children_spec(
+                    arrow_schema.field(step), curr_steps_as_set[step]
+                )
+                for step in curr_steps_as_set
+            ],
+        )
+
+    def _create_parent_index_paths_and_index_from_type_spec(
+        self, type_spec, index, level
+    ):
+        """Populates self._parent_index_paths and self._path_index from the type spec.
+
+        It traverses the prensor type spec to get index and level. It then uses
+        index to get the correct path from self._value_paths.
+
+        This assumes that self._value_paths is sorted alphabetically, and thus the
+        prensor type spec has the same order of paths as self._value_paths.
+
+        Args:
+          type_spec: A Prensor type spec.
+          index: The index of self._value_paths. It is incremented each time we
+            reach a leaf, i.e., we have a new path.
+          level: the step number in a path. It is incremented each time we go to a
+            spec's child. It is then decremented when exiting the child spec.
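+
+        For example (an editorial illustration): with self._value_paths sorted
+        as ("DocId", "Name.Language.Code"), the traversal produces
+        parent_index_paths = ["DocId", "Name.Language.Code", "Name.Language.Code",
+        "Name.Language.Code"] and path_index = [0, 0, 1, 2]: one entry per step
+        of each requested path.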
+ """ + fields = type_spec._children_specs # pylint: disable=protected-access + + for field_tuple in fields: + spec = field_tuple[1] + self._parent_index_paths.append(self._value_paths[index]) + self._path_index.append(level) + level += 1 + self._create_parent_index_paths_and_index_from_type_spec(spec, index, level) + level -= 1 + index += 1 + + @property + def element_spec(self): + return self.element_structure - for p in value_paths: - if p not in paths: - raise ValueError("path " + p + " does not exist in the file.") - def _create_children_spec( - self, field: pa.lib.Field, index_and_paths: List[Tuple[int, - List[path.Step]]] - ) -> Tuple[path.Step, prensor._PrensorTypeSpec]: - """Creates the _PrensorTypeSpec for children and leaves. +def _create_children_from_arrow_fields( + fields: pa.lib.Field, +) -> Dict[str, Dict[Any, Any]]: + """Creates a dictionary of children schema for a pyarrow field. Args: - field: a pyarrow field. - index_and_paths: a list of tuple(index, list[step]), where index is the - column index this step belongs to, and list[step] are children steps of - the passed in step arg. The reason index is needed is because we need to - keep track of which column this step belongs to, to populate - parent_index_paths and path_index. + fields: A list of pyarrow fields. Returns: - a child or leaf _PrensorTypeSpec. + A dictionary of children. Key is field name. Value is a dictionary + representing a schema. """ - # pylint: disable=protected-access - curr_steps_as_set = collections.OrderedDict() - # Construct the dictionary of paths we need. - if len(index_and_paths) >= 1 and len(index_and_paths[0][1]) >= 1: - for p in index_and_paths: - index = p[0] - p = p[1] - curr_step = p[0] - if p: - if curr_step in curr_steps_as_set: - curr_steps_as_set[curr_step].append((index, p[1:])) - else: - curr_steps_as_set[curr_step] = [(index, p[1:])] - - field_type = field.type - if isinstance(field_type, pa.lib.ListType): - field_type = field_type.value_type - is_repeated = True - else: - is_repeated = False - if isinstance(field_type, pa.lib.StructType): - node_type = prensor._PrensorTypeSpec._NodeType.CHILD - dtype = tf.int64 - children = [ - self._create_children_spec(field_type[step], curr_steps_as_set[step]) - for step in curr_steps_as_set - ] - else: - node_type = prensor._PrensorTypeSpec._NodeType.LEAF - dtype = tf.dtypes.as_dtype(field_type) - children = [] - - return (field.name, - prensor._PrensorTypeSpec(is_repeated, node_type, dtype, children)) - - def _create_prensor_spec(self) -> prensor._PrensorTypeSpec: # pylint: disable=protected-access - """Creates the prensor type spec based on value_paths. - - Returns: - a root _PrensorTypeSpec. - """ - metadata = pq.ParquetFile(self._filenames[0]).metadata - parquet_schema = metadata.schema - arrow_schema = parquet_schema.to_arrow_schema() - - # pylint: disable=protected-access - # Sort the paths by number of fields. - paths = [path.create_path(p) for p in self._value_paths] - mapped = zip(paths, self._value_paths, self._value_dtypes) - sorted_mapped = sorted(mapped, key=lambda x: len(x[0].field_list)) - paths, self._value_paths, self._value_dtypes = zip(*sorted_mapped) - - # Creates an ordered dictionary mapping step to a list of children fields. - # This will allow us to find paths that share a parent. 
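+
+      For example (an editorial illustration): for an arrow schema with fields
+      DocId: int64 and Name: list<struct<Url: string>>, the result would be
+      {"DocId": {"is_repeated": False, "dtype": tf.int64},
+       "Name": {"is_repeated": True,
+                "children": {"Url": {"is_repeated": False, "dtype": tf.string}}}}.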
- curr_steps_as_set = collections.OrderedDict() - for (i, p) in enumerate(paths): - step = p.field_list[0] - if step in curr_steps_as_set: - curr_steps_as_set[step].append((i, p.field_list[1:])) - else: - curr_steps_as_set[step] = [(i, p.field_list[1:])] - - return prensor._PrensorTypeSpec( - None, prensor._PrensorTypeSpec._NodeType.ROOT, tf.int64, [ - self._create_children_spec( - arrow_schema.field(step), curr_steps_as_set[step]) - for step in curr_steps_as_set - ]) - - def _create_parent_index_paths_and_index_from_type_spec( - self, type_spec, index, level): - """Populates self._parent_index_paths and self.path_index from the typespec. - - It traverses the prensor type spec to get index and level. It then uses - index to get the correct path from self._value_paths. - - This assumes that self._value_paths is sorted alphabetically, and thus the - prensor type spec has the same order of paths as self._value_paths. - - Args: - type_spec: A Prensor type spec. - index: The index of self._value_paths. It is incremented each time we - reach a leaf, ie we have a new path. - level: the step number in a path. It is incremented each time we go to a - spec's child. It is then decremented when exiting the child spec. - """ - fields = type_spec._children_specs # pylint: disable=protected-access - - for field_tuple in fields: - spec = field_tuple[1] - self._parent_index_paths.append(self._value_paths[index]) - self._path_index.append(level) - level += 1 - self._create_parent_index_paths_and_index_from_type_spec( - spec, index, level) - level -= 1 - index += 1 - - @property - def element_spec(self): - return self.element_structure - - -def _create_children_from_arrow_fields( - fields: pa.lib.Field) -> Dict[str, Dict[Any, Any]]: - """Creates a dictionary of children schema for a pyarrow field. - - Args: - fields: A list of pyarrow fields. - - Returns: - A dictionary of children. Key is field name. Value is a dictionary - representing a schema. 
- """ - children = {} - for field in fields: - field_type = field.type - if isinstance(field_type, pa.lib.ListType): - sub_field_type = field_type.value_type - if isinstance(sub_field_type, pa.lib.StructType): - children[field.name] = { - "is_repeated": - True, - "children": - _create_children_from_arrow_fields( - [subfield for subfield in sub_field_type]) - } - elif isinstance(sub_field_type, pa.lib.DataType): - children[field.name] = { - "is_repeated": True, - "dtype": tf.dtypes.as_dtype(sub_field_type) - } - else: - print("this should never be printed") - elif isinstance(field_type, pa.lib.StructType): - children[field.name] = { - "is_repeated": - False, - "children": - _create_children_from_arrow_fields( - [subfield for subfield in field_type]) - } - else: - children[field.name] = { - "is_repeated": False, - "dtype": tf.dtypes.as_dtype(field_type) - } - return children + children = {} + for field in fields: + field_type = field.type + if isinstance(field_type, pa.lib.ListType): + sub_field_type = field_type.value_type + if isinstance(sub_field_type, pa.lib.StructType): + children[field.name] = { + "is_repeated": True, + "children": _create_children_from_arrow_fields( + [subfield for subfield in sub_field_type] + ), + } + elif isinstance(sub_field_type, pa.lib.DataType): + children[field.name] = { + "is_repeated": True, + "dtype": tf.dtypes.as_dtype(sub_field_type), + } + else: + print("this should never be printed") + elif isinstance(field_type, pa.lib.StructType): + children[field.name] = { + "is_repeated": False, + "children": _create_children_from_arrow_fields( + [subfield for subfield in field_type] + ), + } + else: + children[field.name] = { + "is_repeated": False, + "dtype": tf.dtypes.as_dtype(field_type), + } + return children class _ParquetDatasetWithExpression(ParquetDataset): - """A dataset which reads columns from a parquet file based on the expressions. - - The data read from the parquet file will then have the expression queries - applied to it, creating a new prensor. - - This dataset should not be created by the user, call - parquet_dataset.calculate_parquet_values() to get this dataset instead. - """ - - def __init__(self, exprs: List[expression.Expression], - root_expr: placeholder._PlaceholderRootExpression, - filenames: List[str], batch_size: int, - options: Optional[calculate_options.Options]): - self._exprs = exprs - self._root_expr = root_expr - self._filesnames = filenames - self._batch_size = batch_size - self._options = options - - # pylint: disable=protected-access - self._subtrees = [x.get_known_descendants() for x in self._exprs] - self._all_expressions = [] - for tree in self._subtrees: - self._all_expressions.extend(tree.values()) - - expression_graph = calculate.OriginalExpressionGraph(self._all_expressions) - self._canonical_graph = calculate.CanonicalExpressionGraph(expression_graph) - paths = placeholder.get_placeholder_paths_from_graph(self._canonical_graph) - - parquet_paths = [".".join(p.field_list) for p in paths] - - super().__init__(filenames, parquet_paths, batch_size) - - def _calculate_prensor(self, pren) -> List[prensor.Prensor]: - """Function for applying expression queries to a prensor. + """A dataset which reads columns from a parquet file based on the expressions. - This function should be passed into dataset.map(). + The data read from the parquet file will then have the expression queries + applied to it, creating a new prensor. - Args: - pren: The prensor that will be used to bind to the root expression. 
- - Returns: - A list of modified prensor that have the expression queries applied. + This dataset should not be created by the user, call + parquet_dataset.calculate_parquet_values() to get this dataset instead. """ - self._canonical_graph.calculate_values( - options=self._options, feed_dict={self._root_expr: pren}) - values = [ - self._canonical_graph.get_value_or_die(x) for x in self._all_expressions - ] - expr_to_value_map = { - id(expr): value for expr, value in zip(self._all_expressions, values) - } + def __init__( + self, + exprs: List[expression.Expression], + root_expr: placeholder._PlaceholderRootExpression, + filenames: List[str], + batch_size: int, + options: Optional[calculate_options.Options], + ): + self._exprs = exprs + self._root_expr = root_expr + self._filesnames = filenames + self._batch_size = batch_size + self._options = options + + # pylint: disable=protected-access + self._subtrees = [x.get_known_descendants() for x in self._exprs] + self._all_expressions = [] + for tree in self._subtrees: + self._all_expressions.extend(tree.values()) + + expression_graph = calculate.OriginalExpressionGraph(self._all_expressions) + self._canonical_graph = calculate.CanonicalExpressionGraph(expression_graph) + paths = placeholder.get_placeholder_paths_from_graph(self._canonical_graph) + + parquet_paths = [".".join(p.field_list) for p in paths] + + super().__init__(filenames, parquet_paths, batch_size) + + def _calculate_prensor(self, pren) -> List[prensor.Prensor]: + """Function for applying expression queries to a prensor. + + This function should be passed into dataset.map(). + + Args: + pren: The prensor that will be used to bind to the root expression. + + Returns: + A list of modified prensor that have the expression queries applied. + """ + self._canonical_graph.calculate_values( + options=self._options, feed_dict={self._root_expr: pren} + ) + values = [ + self._canonical_graph.get_value_or_die(x) for x in self._all_expressions + ] + + expr_to_value_map = { + id(expr): value for expr, value in zip(self._all_expressions, values) + } - # pylint: disable=protected-access - return [ - calculate._get_prensor(subtree, expr_to_value_map) - for subtree in self._subtrees - ] + # pylint: disable=protected-access + return [ + calculate._get_prensor(subtree, expr_to_value_map) + for subtree in self._subtrees + ] diff --git a/struct2tensor/expression_impl/parquet_test.py b/struct2tensor/expression_impl/parquet_test.py index 4a44c9a..bf655ae 100644 --- a/struct2tensor/expression_impl/parquet_test.py +++ b/struct2tensor/expression_impl/parquet_test.py @@ -17,7 +17,7 @@ from absl.testing import absltest from pyarrow.lib import ArrowIOError from tensorflow.python.framework import ( - test_util, # pylint: disable=g-direct-tensorflow-import + test_util, # pylint: disable=g-direct-tensorflow-import ) from struct2tensor import path, prensor @@ -28,161 +28,175 @@ @test_util.run_all_in_graph_and_eager_modes class ParquetDatasetTestBase(tf.test.TestCase): - - def setUp(self): - super().setUp() - self._test_filenames = [ - "struct2tensor/testdata/parquet_testdata/dremel_example.parquet" - ] - self._rowgroup_test_filenames = [ - "struct2tensor/testdata/parquet_testdata/dremel_example_two_row_groups.parquet" - ] - - def _assertPrensorEqual(self, result, expected): - """Traverses prensors level order, to check that the two prensors are equal. - - Args: - result: the resulting prensor. 
- expected: the expected prensor - """ - res_node = result.node - exp_node = expected.node - - self.assertEqual(type(res_node), type(exp_node)) - - if result.is_leaf: - self.assertAllEqual(res_node.parent_index, exp_node.parent_index) - self.assertAllEqual(res_node.values, exp_node.values) - else: - if isinstance(res_node, prensor.RootNodeTensor): - self.assertEqual(res_node.size, exp_node.size) - else: - self.assertAllEqual(res_node.parent_index, exp_node.parent_index) - res_children = result.get_children() - exp_children = expected.get_children() - - for child_step in res_children: - self._assertPrensorEqual(res_children[child_step], - exp_children[child_step]) + def setUp(self): + super().setUp() + self._test_filenames = [ + "struct2tensor/testdata/parquet_testdata/dremel_example.parquet" + ] + self._rowgroup_test_filenames = [ + "struct2tensor/testdata/parquet_testdata/dremel_example_two_row_groups.parquet" + ] + + def _assertPrensorEqual(self, result, expected): + """Traverses prensors level order, to check that the two prensors are equal. + + Args: + result: the resulting prensor. + expected: the expected prensor + """ + res_node = result.node + exp_node = expected.node + + self.assertEqual(type(res_node), type(exp_node)) + + if result.is_leaf: + self.assertAllEqual(res_node.parent_index, exp_node.parent_index) + self.assertAllEqual(res_node.values, exp_node.values) + else: + if isinstance(res_node, prensor.RootNodeTensor): + self.assertEqual(res_node.size, exp_node.size) + else: + self.assertAllEqual(res_node.parent_index, exp_node.parent_index) + res_children = result.get_children() + exp_children = expected.get_children() + + for child_step in res_children: + self._assertPrensorEqual( + res_children[child_step], exp_children[child_step] + ) class ParquetDatasetOutputsPrensorTest(ParquetDatasetTestBase): - """This tests the public facing API, ParquetDataset().""" - - def testInvalidFile(self): - """Tests exception is thrown for invalid file.""" - with self.assertRaisesRegex(ArrowIOError, "Failed to open"): - parquet.ParquetDataset( - filenames=["invalid"], value_paths=["DocId"], batch_size=1) - - def testInvalidColumnName(self): - """Tests exception is thrown for invalid column name.""" - with self.assertRaisesRegex(ValueError, "path does not exist in the file."): - parquet.ParquetDataset( - filenames=self._test_filenames, - value_paths=["invalid_path"], - batch_size=1) - - def testPrensorOutput(self): - """Tests that the dataset outputs a prensor.""" - pq_ds = parquet.ParquetDataset( - self._test_filenames, - value_paths=["Name.Language.Country", "DocId", "Name.Language.Code"], - batch_size=1) - - for (i, pren) in enumerate(pq_ds): - doc_id = pren.get_descendant_or_error(path.Path(["DocId"])).node - self.assertAllEqual(doc_id.parent_index, [0]) - self.assertAllEqual(doc_id.values, [(i + 1) * 10]) - code = pren.get_descendant_or_error( - path.Path(["Name", "Language", "Code"])).node - if i == 0: - self.assertAllEqual(code.parent_index, [0, 1, 2]) - self.assertAllEqual(code.values, [b"en-us", b"en", b"en-gb"]) - else: - self.assertAllEqual(code.parent_index, []) - self.assertAllEqual(code.values, []) - - def testMultipleColumnsTwoRowGroupsAndEqualBatchSize_OutputsPrensor(self): - """Tests that the correct prensor for three columns is outputted.""" - pq_ds = parquet.ParquetDataset( - filenames=self._rowgroup_test_filenames, - value_paths=["DocId", "Name.Language.Code", "Name.Language.Country"], - batch_size=2) - expected_prensor = prensor.create_prensor_from_descendant_nodes({ - 
path.Path([]): - prensor.RootNodeTensor(tf.constant(2, dtype=tf.int64)), - path.Path(["DocId"]): - prensor.LeafNodeTensor( - tf.constant([0, 1], dtype=tf.int64), - tf.constant([10, 20], dtype=tf.int64), True), - path.Path(["Name"]): - prensor.ChildNodeTensor( - tf.constant([0, 0, 0, 1], dtype=tf.int64), True), - path.Path(["Name", "Language"]): - prensor.ChildNodeTensor( - tf.constant([0, 0, 2], dtype=tf.int64), True), - path.Path(["Name", "Language", "Code"]): - prensor.LeafNodeTensor( - tf.constant([0, 1, 2], dtype=tf.int64), - tf.constant([b"en-us", b"en", b"en-gb"]), True), - path.Path(["Name", "Language", "Country"]): - prensor.LeafNodeTensor( - tf.constant([0, 2], dtype=tf.int64), tf.constant([b"us", - b"gb"]), True) - }) - - for i, pren in enumerate(pq_ds): - if i == 0: - self._assertPrensorEqual(pren, expected_prensor) + """This tests the public facing API, ParquetDataset().""" + + def testInvalidFile(self): + """Tests exception is thrown for invalid file.""" + with self.assertRaisesRegex(ArrowIOError, "Failed to open"): + parquet.ParquetDataset( + filenames=["invalid"], value_paths=["DocId"], batch_size=1 + ) + + def testInvalidColumnName(self): + """Tests exception is thrown for invalid column name.""" + with self.assertRaisesRegex(ValueError, "path does not exist in the file."): + parquet.ParquetDataset( + filenames=self._test_filenames, + value_paths=["invalid_path"], + batch_size=1, + ) + + def testPrensorOutput(self): + """Tests that the dataset outputs a prensor.""" + pq_ds = parquet.ParquetDataset( + self._test_filenames, + value_paths=["Name.Language.Country", "DocId", "Name.Language.Code"], + batch_size=1, + ) + + for i, pren in enumerate(pq_ds): + doc_id = pren.get_descendant_or_error(path.Path(["DocId"])).node + self.assertAllEqual(doc_id.parent_index, [0]) + self.assertAllEqual(doc_id.values, [(i + 1) * 10]) + code = pren.get_descendant_or_error( + path.Path(["Name", "Language", "Code"]) + ).node + if i == 0: + self.assertAllEqual(code.parent_index, [0, 1, 2]) + self.assertAllEqual(code.values, [b"en-us", b"en", b"en-gb"]) + else: + self.assertAllEqual(code.parent_index, []) + self.assertAllEqual(code.values, []) + + def testMultipleColumnsTwoRowGroupsAndEqualBatchSize_OutputsPrensor(self): + """Tests that the correct prensor for three columns is outputted.""" + pq_ds = parquet.ParquetDataset( + filenames=self._rowgroup_test_filenames, + value_paths=["DocId", "Name.Language.Code", "Name.Language.Country"], + batch_size=2, + ) + expected_prensor = prensor.create_prensor_from_descendant_nodes( + { + path.Path([]): prensor.RootNodeTensor(tf.constant(2, dtype=tf.int64)), + path.Path(["DocId"]): prensor.LeafNodeTensor( + tf.constant([0, 1], dtype=tf.int64), + tf.constant([10, 20], dtype=tf.int64), + True, + ), + path.Path(["Name"]): prensor.ChildNodeTensor( + tf.constant([0, 0, 0, 1], dtype=tf.int64), True + ), + path.Path(["Name", "Language"]): prensor.ChildNodeTensor( + tf.constant([0, 0, 2], dtype=tf.int64), True + ), + path.Path(["Name", "Language", "Code"]): prensor.LeafNodeTensor( + tf.constant([0, 1, 2], dtype=tf.int64), + tf.constant([b"en-us", b"en", b"en-gb"]), + True, + ), + path.Path(["Name", "Language", "Country"]): prensor.LeafNodeTensor( + tf.constant([0, 2], dtype=tf.int64), + tf.constant([b"us", b"gb"]), + True, + ), + } + ) + + for i, pren in enumerate(pq_ds): + if i == 0: + self._assertPrensorEqual(pren, expected_prensor) class ParquetDatasetWithExpressionTest(ParquetDatasetTestBase): - """This tests the public facing API, using the placeholder 
expression.""" - - def testPromoteAndProjectExpression(self): - filenames = [ - "struct2tensor/testdata/parquet_testdata/dremel_example.parquet" - ] - batch_size = 2 - exp = parquet.create_expression_from_parquet_file(filenames) - new_exp = promote.promote(exp, path.Path(["Name", "Language", "Code"]), - "new_code") - new_code_project_exp = project.project(new_exp, - [path.Path(["Name", "new_code"])]) - docid_project_exp = project.project(exp, [path.Path(["DocId"])]) - - pqds = parquet.calculate_parquet_values( - [new_code_project_exp, docid_project_exp], exp, filenames, batch_size) - - new_code_expected = prensor.create_prensor_from_descendant_nodes({ - path.Path([]): - prensor.RootNodeTensor(tf.constant(2, dtype=tf.int64)), - path.Path(["Name"]): - prensor.ChildNodeTensor( - tf.constant([0, 0, 0, 1], dtype=tf.int64), True), - path.Path(["Name", "new_code"]): - prensor.LeafNodeTensor( - tf.constant([0, 0, 2], dtype=tf.int64), - tf.constant([b"en-us", b"en", b"en-gb"]), True) - }) - - docid_expected = prensor.create_prensor_from_descendant_nodes({ - path.Path([]): - prensor.RootNodeTensor(tf.constant(2, dtype=tf.int64)), - path.Path(["DocId"]): - prensor.LeafNodeTensor( - tf.constant([0, 1], dtype=tf.int64), - tf.constant([10, 20], dtype=tf.int64), False) - }) - - for ele in pqds: - new_code_pren = ele[0] - docid_pren = ele[1] - - self._assertPrensorEqual(new_code_pren, new_code_expected) - self._assertPrensorEqual(docid_pren, docid_expected) + """This tests the public facing API, using the placeholder expression.""" + + def testPromoteAndProjectExpression(self): + filenames = ["struct2tensor/testdata/parquet_testdata/dremel_example.parquet"] + batch_size = 2 + exp = parquet.create_expression_from_parquet_file(filenames) + new_exp = promote.promote( + exp, path.Path(["Name", "Language", "Code"]), "new_code" + ) + new_code_project_exp = project.project( + new_exp, [path.Path(["Name", "new_code"])] + ) + docid_project_exp = project.project(exp, [path.Path(["DocId"])]) + + pqds = parquet.calculate_parquet_values( + [new_code_project_exp, docid_project_exp], exp, filenames, batch_size + ) + + new_code_expected = prensor.create_prensor_from_descendant_nodes( + { + path.Path([]): prensor.RootNodeTensor(tf.constant(2, dtype=tf.int64)), + path.Path(["Name"]): prensor.ChildNodeTensor( + tf.constant([0, 0, 0, 1], dtype=tf.int64), True + ), + path.Path(["Name", "new_code"]): prensor.LeafNodeTensor( + tf.constant([0, 0, 2], dtype=tf.int64), + tf.constant([b"en-us", b"en", b"en-gb"]), + True, + ), + } + ) + + docid_expected = prensor.create_prensor_from_descendant_nodes( + { + path.Path([]): prensor.RootNodeTensor(tf.constant(2, dtype=tf.int64)), + path.Path(["DocId"]): prensor.LeafNodeTensor( + tf.constant([0, 1], dtype=tf.int64), + tf.constant([10, 20], dtype=tf.int64), + False, + ), + } + ) + + for ele in pqds: + new_code_pren = ele[0] + docid_pren = ele[1] + + self._assertPrensorEqual(new_code_pren, new_code_expected) + self._assertPrensorEqual(docid_pren, docid_expected) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_impl/parse_message_level_ex.py b/struct2tensor/expression_impl/parse_message_level_ex.py index 9f072f6..1fe5db5 100644 --- a/struct2tensor/expression_impl/parse_message_level_ex.py +++ b/struct2tensor/expression_impl/parse_message_level_ex.py @@ -97,183 +97,204 @@ def parse_message_level_ex( field_names: Set[ProtoFieldName], message_format: str = "binary", backing_str_tensor: Optional[tf.Tensor] = None, - 
honor_proto3_optional_semantics: bool = False + honor_proto3_optional_semantics: bool = False, ) -> Mapping[StrStep, struct2tensor_ops._ParsedField]: - """Parses regular fields, extensions, any casts, and map protos.""" - raw_field_names = _get_field_names_to_parse(desc, field_names) - regular_fields = list( - struct2tensor_ops.parse_message_level( - tensor_of_protos, - desc, - raw_field_names, - message_format=message_format, - backing_str_tensor=backing_str_tensor, - honor_proto3_optional_semantics=honor_proto3_optional_semantics)) - regular_field_map = {x.field_name: x for x in regular_fields} - - any_fields = _get_any_parsed_fields(desc, regular_field_map, field_names) - map_fields = _get_map_parsed_fields(desc, regular_field_map, field_names, - backing_str_tensor) - result = regular_field_map - result.update(any_fields) - result.update(map_fields) - return result + """Parses regular fields, extensions, any casts, and map protos.""" + raw_field_names = _get_field_names_to_parse(desc, field_names) + regular_fields = list( + struct2tensor_ops.parse_message_level( + tensor_of_protos, + desc, + raw_field_names, + message_format=message_format, + backing_str_tensor=backing_str_tensor, + honor_proto3_optional_semantics=honor_proto3_optional_semantics, + ) + ) + regular_field_map = {x.field_name: x for x in regular_fields} + + any_fields = _get_any_parsed_fields(desc, regular_field_map, field_names) + map_fields = _get_map_parsed_fields( + desc, regular_field_map, field_names, backing_str_tensor + ) + result = regular_field_map + result.update(any_fields) + result.update(map_fields) + return result def is_any_descriptor(desc: descriptor.Descriptor) -> bool: - """Returns true if it is an Any descriptor.""" - return desc.full_name == "google.protobuf.Any" - - -def get_full_name_from_any_step( - step: ProtoFieldName) -> Optional[ProtoFieldName]: - """Gets the full name of a protobuf from a google.protobuf.Any step. - - An any step is of the form (foo.com/bar.Baz). In this case the result would - be bar.Baz. - - Args: - step: the string of a step in a path. - - Returns: - the full name of a protobuf if the step is an any step, or None otherwise. 
- """ - if not step: - return None - if step[0] != "(": - return None - if step[-1] != ")": - return None - step_without_parens = step[1:-1] - return step_without_parens.split("/")[-1] - - -def _any_indices_with_type(type_url: struct2tensor_ops._ParsedField, - full_name: ProtoFullName) -> tf.Tensor: - """Returns the parent indices that have a type_url of full_name.""" - tensors_parsed = tf.compat.v1.string_split(type_url.value, delimiter="/") - second_column_shape = tf.stack([ - tf.shape(type_url.value, out_type=tf.int64)[0], - tf.constant(1, dtype=tf.int64) - ], - axis=0) - second_column = tf.reshape( - tf.sparse.to_dense( - tf.sparse.slice(tensors_parsed, tf.constant([0, 1], dtype=tf.int64), - second_column_shape), - default_value=""), [-1]) - equal_to_full_name = tf.equal(second_column, full_name) - return tf.boolean_mask(type_url.index, equal_to_full_name) - - -def _get_any_parsed_field(value_field: struct2tensor_ops._ParsedField, - type_url_field: struct2tensor_ops._ParsedField, - field_name: StrStep - ) -> struct2tensor_ops._ParsedField: - """Helper function for _get_any_parsed_fields.""" - full_name = get_full_name_from_any_step(field_name) - indices_with_type = _any_indices_with_type(type_url_field, full_name) - [index_to_solution_index, index_to_values - ] = struct2tensor_ops.equi_join_indices(indices_with_type, value_field.index) - solution_index = tf.gather(indices_with_type, index_to_solution_index) - solution_value = tf.gather(value_field.value, index_to_values) - # TODO(martinz): make _ParsedField public. - return struct2tensor_ops._ParsedField( # pylint: disable=protected-access - field_name=field_name, - field_descriptor=None, - index=solution_index, - value=solution_value) + """Returns true if it is an Any descriptor.""" + return desc.full_name == "google.protobuf.Any" + + +def get_full_name_from_any_step(step: ProtoFieldName) -> Optional[ProtoFieldName]: + """Gets the full name of a protobuf from a google.protobuf.Any step. + + An any step is of the form (foo.com/bar.Baz). In this case the result would + be bar.Baz. + + Args: + step: the string of a step in a path. + + Returns: + the full name of a protobuf if the step is an any step, or None otherwise. 
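+
+    For example (an editorial illustration):
+      get_full_name_from_any_step("(type.googleapis.com/bar.Baz)")  # -> "bar.Baz"
+      get_full_name_from_any_step("regular_step")  # -> None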
+ """ + if not step: + return None + if step[0] != "(": + return None + if step[-1] != ")": + return None + step_without_parens = step[1:-1] + return step_without_parens.split("/")[-1] + + +def _any_indices_with_type( + type_url: struct2tensor_ops._ParsedField, full_name: ProtoFullName +) -> tf.Tensor: + """Returns the parent indices that have a type_url of full_name.""" + tensors_parsed = tf.compat.v1.string_split(type_url.value, delimiter="/") + second_column_shape = tf.stack( + [ + tf.shape(type_url.value, out_type=tf.int64)[0], + tf.constant(1, dtype=tf.int64), + ], + axis=0, + ) + second_column = tf.reshape( + tf.sparse.to_dense( + tf.sparse.slice( + tensors_parsed, tf.constant([0, 1], dtype=tf.int64), second_column_shape + ), + default_value="", + ), + [-1], + ) + equal_to_full_name = tf.equal(second_column, full_name) + return tf.boolean_mask(type_url.index, equal_to_full_name) + + +def _get_any_parsed_field( + value_field: struct2tensor_ops._ParsedField, + type_url_field: struct2tensor_ops._ParsedField, + field_name: StrStep, +) -> struct2tensor_ops._ParsedField: + """Helper function for _get_any_parsed_fields.""" + full_name = get_full_name_from_any_step(field_name) + indices_with_type = _any_indices_with_type(type_url_field, full_name) + [index_to_solution_index, index_to_values] = struct2tensor_ops.equi_join_indices( + indices_with_type, value_field.index + ) + solution_index = tf.gather(indices_with_type, index_to_solution_index) + solution_value = tf.gather(value_field.value, index_to_values) + # TODO(martinz): make _ParsedField public. + return struct2tensor_ops._ParsedField( # pylint: disable=protected-access + field_name=field_name, + field_descriptor=None, + index=solution_index, + value=solution_value, + ) def _get_any_parsed_fields( desc: descriptor.Descriptor, raw_parsed_fields: Mapping[StrStep, struct2tensor_ops._ParsedField], - field_names: Set[StrStep] + field_names: Set[StrStep], ) -> Mapping[StrStep, struct2tensor_ops._ParsedField]: - """Gets the _ParsedField sequence for an Any protobuf.""" - if not is_any_descriptor(desc): - return {} + """Gets the _ParsedField sequence for an Any protobuf.""" + if not is_any_descriptor(desc): + return {} - result = [] # type: List[struct2tensor_ops._ParsedField] + result = [] # type: List[struct2tensor_ops._ParsedField] - for x in field_names: - if path.is_extension(x): - result.append( - _get_any_parsed_field(raw_parsed_fields["value"], - raw_parsed_fields["type_url"], x)) - return {x.field_name: x for x in result} + for x in field_names: + if path.is_extension(x): + result.append( + _get_any_parsed_field( + raw_parsed_fields["value"], raw_parsed_fields["type_url"], x + ) + ) + return {x.field_name: x for x in result} def _get_field_names_to_parse( - desc: descriptor.Descriptor, - needed_field_names: Set[StrStep]) -> Sequence[ProtoFieldName]: - """Gets the field names to parse from the original protobuf.""" - result = set() # Set[ProtoFieldName] - for x in needed_field_names: - if path.is_map_indexing_step(x): - map_field_name, _ = path.parse_map_indexing_step(x) - result.add(map_field_name) - elif path.is_extension(x) and is_any_descriptor(desc): - result.add("type_url") - result.add("value") - else: - result.add(x) - return list(result) + desc: descriptor.Descriptor, needed_field_names: Set[StrStep] +) -> Sequence[ProtoFieldName]: + """Gets the field names to parse from the original protobuf.""" + result = set() # Set[ProtoFieldName] + for x in needed_field_names: + if path.is_map_indexing_step(x): + map_field_name, _ = 
path.parse_map_indexing_step(x)
+            result.add(map_field_name)
+        elif path.is_extension(x) and is_any_descriptor(desc):
+            result.add("type_url")
+            result.add("value")
+        else:
+            result.add(x)
+    return list(result)
+
+
+def _get_map_parsed_fields(
+    desc: descriptor.Descriptor,
+    regular_fields: Mapping[StrStep, struct2tensor_ops._ParsedField],
+    field_names: Set[StrStep],
+    backing_str_tensor: Optional[tf.Tensor] = None,
+) -> Mapping[StrStep, struct2tensor_ops._ParsedField]:
+    """Gets the map proto ParsedFields.
+
+    field_names includes all the fields: map fields, any fields, and
+    regular fields.
+
+    Args:
+      desc: the descriptor of the parent proto.
+      regular_fields: the fields that are parsed directly from the proto.
+      field_names: all fields needed: map fields, any fields, and regular fields.
+      backing_str_tensor: a string tensor representing the root serialized proto.
+        This is passed to keep string_views of the tensor valid for all children
+        of the root expression.
+
+    Returns:
+      A map from field names to ParsedFields, only for field names of the form
+      foo[bar].
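+
+      For example (an editorial illustration, mirroring test_bool_key_type in
+      parse_message_level_ex_test.py below): requesting "bool_string_map[1]"
+      parses the map field "bool_string_map" and returns the values stored
+      under key "1", with a parent index aligning each value to its message.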
+ """ + maps_to_parse = collections.defaultdict(dict) + for x in field_names: + if path.is_map_indexing_step(x): + map_field_name, key = path.parse_map_indexing_step(x) + maps_to_parse[map_field_name][key] = x + result_as_list = [] + for map_field_name, v in maps_to_parse.items(): + parsed_map_field = regular_fields[map_field_name] + keys_needed = list(v.keys()) + map_field_value = parsed_map_field.value + map_field_index = parsed_map_field.index + map_field_desc = desc.fields_by_name[map_field_name].message_type + values_and_parent_indices = struct2tensor_ops.parse_proto_map( + map_field_value, + map_field_index, + map_field_desc, + keys_needed, + backing_str_tensor, + ) + for map_key, [value, parent_index] in zip( + keys_needed, values_and_parent_indices + ): + result_as_list.append( + struct2tensor_ops._ParsedField( + field_name=v[map_key], + field_descriptor=None, + index=parent_index, + value=value, + ) + ) + return {x.field_name: x for x in result_as_list} + + +def _get_parsed_field( + data: Sequence[struct2tensor_ops._ParsedField], field_name: ProtoFieldName +) -> Optional[struct2tensor_ops._ParsedField]: + for x in data: + if x.field_name == field_name: + return x + return None diff --git a/struct2tensor/expression_impl/parse_message_level_ex_test.py b/struct2tensor/expression_impl/parse_message_level_ex_test.py index c6e1aed..1f06b7c 100644 --- a/struct2tensor/expression_impl/parse_message_level_ex_test.py +++ b/struct2tensor/expression_impl/parse_message_level_ex_test.py @@ -22,7 +22,7 @@ from absl.testing import absltest, parameterized from google.protobuf import text_format from tensorflow.python.framework import ( - test_util, # pylint: disable=g-direct-tensorflow-import + test_util, # pylint: disable=g-direct-tensorflow-import ) from struct2tensor.expression_impl import parse_message_level_ex @@ -40,142 +40,149 @@ def _run_parse_message_level_ex(proto_list, fields, message_format="binary"): - if message_format == "text": - serialized = [text_format.MessageToString(x) for x in proto_list] - elif message_format == "binary": - serialized = [x.SerializeToString() for x in proto_list] - else: - raise ValueError('Message format must be one of "text", "binary"') - parsed_field_dict = parse_message_level_ex.parse_message_level_ex( - tf.constant(serialized), proto_list[0].DESCRIPTOR, fields, message_format) - sess_input = {} - for key, value in parsed_field_dict.items(): - local_dict = {} - local_dict[_INDEX] = value.index - local_dict[_VALUE] = value.value - sess_input[key] = local_dict - return sess_input + if message_format == "text": + serialized = [text_format.MessageToString(x) for x in proto_list] + elif message_format == "binary": + serialized = [x.SerializeToString() for x in proto_list] + else: + raise ValueError('Message format must be one of "text", "binary"') + parsed_field_dict = parse_message_level_ex.parse_message_level_ex( + tf.constant(serialized), proto_list[0].DESCRIPTOR, fields, message_format + ) + sess_input = {} + for key, value in parsed_field_dict.items(): + local_dict = {} + local_dict[_INDEX] = value.index + local_dict[_VALUE] = value.value + sess_input[key] = local_dict + return sess_input def _create_any(x): - """Create an any object from a protobuf.""" - new_any = test_any_pb2.MessageWithAny().my_any - new_any.Pack(x) - return new_any + """Create an any object from a protobuf.""" + new_any = test_any_pb2.MessageWithAny().my_any + new_any.Pack(x) + return new_any def _create_any_protos(): - my_value_0 = test_pb2.AllSimple() - my_value_0.optional_int32 = 0 
- my_value_1 = test_pb2.UserInfo() - my_value_2 = test_pb2.AllSimple() - my_value_2.optional_int32 = 20 - return [_create_any(x) for x in [my_value_0, my_value_1, my_value_2]] + my_value_0 = test_pb2.AllSimple() + my_value_0.optional_int32 = 0 + my_value_1 = test_pb2.UserInfo() + my_value_2 = test_pb2.AllSimple() + my_value_2.optional_int32 = 20 + return [_create_any(x) for x in [my_value_0, my_value_1, my_value_2]] def _get_optional_int32(serialized_all_simple): - """Take a serialized test_pb2.AllSimple object and extract optional_int32.""" - holder = test_pb2.AllSimple() - holder.ParseFromString(serialized_all_simple) - return holder.optional_int32 + """Take a serialized test_pb2.AllSimple object and extract optional_int32.""" + holder = test_pb2.AllSimple() + holder.ParseFromString(serialized_all_simple) + return holder.optional_int32 def _get_empty_all_simple(): - """Take a serialized test_pb2.AllSimple object and extract optional_int32.""" - return test_pb2.AllSimple().SerializeToString() + """Take a serialized test_pb2.AllSimple object and extract optional_int32.""" + return test_pb2.AllSimple().SerializeToString() @test_util.run_all_in_graph_and_eager_modes class ParseMessageLevelExTest(parameterized.TestCase, tf.test.TestCase): - - # TODO(askerryryan): Consider supporting Any types for text format. Currently - # only binary format is supported. - def test_any_field(self): - original_protos = _create_any_protos() - result = _run_parse_message_level_ex(original_protos, {_ALLSIMPLE}) - self.assertIn(_ALLSIMPLE, result) - self.assertIn(_INDEX, result[_ALLSIMPLE]) - self.assertIn(_VALUE, result[_ALLSIMPLE]) - - self.assertAllEqual(result[_ALLSIMPLE][_INDEX], [0, 2]) - result_optional_int32 = [ - _get_optional_int32(x) - for x in self.evaluate(result[_ALLSIMPLE][_VALUE]) - ] - self.assertAllEqual(result_optional_int32, [0, 20]) - - def test_any_field_no_special(self): - result = _run_parse_message_level_ex(_create_any_protos(), {_TYPE_URL}) - self.assertIn(_TYPE_URL, result) - self.assertAllEqual(result[_TYPE_URL][_INDEX], [0, 1, 2]) - - self.assertAllEqual( - result[_TYPE_URL][_VALUE], - [_ALLSIMPLE_NO_PARENS, _USERINFO_NO_PARENS, _ALLSIMPLE_NO_PARENS]) - - def test_any_field_special_and_type_url(self): - result = _run_parse_message_level_ex(_create_any_protos(), - {_TYPE_URL, _ALLSIMPLE}) - self.assertIn(_TYPE_URL, result) - self.assertAllEqual(result[_TYPE_URL][_INDEX], [0, 1, 2]) - - self.assertAllEqual( - result[_TYPE_URL][_VALUE], - [_ALLSIMPLE_NO_PARENS, _USERINFO_NO_PARENS, _ALLSIMPLE_NO_PARENS]) - - self.assertAllEqual(result[_ALLSIMPLE][_INDEX], [0, 2]) - actual_values = [ - _get_optional_int32(x) - for x in self.evaluate(result[_ALLSIMPLE][_VALUE]) - ] - self.assertAllEqual(actual_values, [0, 20]) - - def test_full_name_from_any_step(self): - self.assertEqual( - parse_message_level_ex.get_full_name_from_any_step(_ALLSIMPLE), - "struct2tensor.test.AllSimple") - self.assertEqual( - parse_message_level_ex.get_full_name_from_any_step(_USERINFO), - "struct2tensor.test.UserInfo") - self.assertIsNone( - parse_message_level_ex.get_full_name_from_any_step("broken")) - self.assertIsNone( - parse_message_level_ex.get_full_name_from_any_step("(broken")) - self.assertIsNone( - parse_message_level_ex.get_full_name_from_any_step("broken)")) - - @parameterized.named_parameters( - [dict(testcase_name=f, message_format=f) for f in _MESSAGE_FORMATS]) - def test_normal_field(self, message_format): - """Test three messages with a repeated string.""" - all_simple = test_pb2.AllSimple() - 
all_simple.repeated_string.append("foo") - all_simple.repeated_string.append("foo2") - all_simple_empty = test_pb2.AllSimple() - - result = _run_parse_message_level_ex( - [all_simple, all_simple_empty, all_simple, all_simple], - {"repeated_string"}, message_format) - self.assertNotIn("repeated_bool", result) - - self.assertAllEqual(result["repeated_string"][_INDEX], [0, 0, 2, 2, 3, 3]) - self.assertAllEqual(result["repeated_string"][_VALUE], - [b"foo", b"foo2", b"foo", b"foo2", b"foo", b"foo2"]) - - @parameterized.named_parameters( - [dict(testcase_name=f, message_format=f) for f in _MESSAGE_FORMATS]) - def test_bool_key_type(self, message_format): - map_field = "bool_string_map[1]" - message_with_map_0 = test_map_pb2.MessageWithMap() - message_with_map_0.bool_string_map[False] = "hello" - message_with_map_1 = test_map_pb2.MessageWithMap() - message_with_map_1.bool_string_map[True] = "goodbye" - result = _run_parse_message_level_ex( - [message_with_map_0, message_with_map_1], {map_field}, message_format) - self.assertIn(map_field, result) - self.assertAllEqual(result[map_field][_VALUE], [b"goodbye"]) - self.assertAllEqual(result[map_field][_INDEX], [1]) + # TODO(askerryryan): Consider supporting Any types for text format. Currently + # only binary format is supported. + def test_any_field(self): + original_protos = _create_any_protos() + result = _run_parse_message_level_ex(original_protos, {_ALLSIMPLE}) + self.assertIn(_ALLSIMPLE, result) + self.assertIn(_INDEX, result[_ALLSIMPLE]) + self.assertIn(_VALUE, result[_ALLSIMPLE]) + + self.assertAllEqual(result[_ALLSIMPLE][_INDEX], [0, 2]) + result_optional_int32 = [ + _get_optional_int32(x) for x in self.evaluate(result[_ALLSIMPLE][_VALUE]) + ] + self.assertAllEqual(result_optional_int32, [0, 20]) + + def test_any_field_no_special(self): + result = _run_parse_message_level_ex(_create_any_protos(), {_TYPE_URL}) + self.assertIn(_TYPE_URL, result) + self.assertAllEqual(result[_TYPE_URL][_INDEX], [0, 1, 2]) + + self.assertAllEqual( + result[_TYPE_URL][_VALUE], + [_ALLSIMPLE_NO_PARENS, _USERINFO_NO_PARENS, _ALLSIMPLE_NO_PARENS], + ) + + def test_any_field_special_and_type_url(self): + result = _run_parse_message_level_ex( + _create_any_protos(), {_TYPE_URL, _ALLSIMPLE} + ) + self.assertIn(_TYPE_URL, result) + self.assertAllEqual(result[_TYPE_URL][_INDEX], [0, 1, 2]) + + self.assertAllEqual( + result[_TYPE_URL][_VALUE], + [_ALLSIMPLE_NO_PARENS, _USERINFO_NO_PARENS, _ALLSIMPLE_NO_PARENS], + ) + + self.assertAllEqual(result[_ALLSIMPLE][_INDEX], [0, 2]) + actual_values = [ + _get_optional_int32(x) for x in self.evaluate(result[_ALLSIMPLE][_VALUE]) + ] + self.assertAllEqual(actual_values, [0, 20]) + + def test_full_name_from_any_step(self): + self.assertEqual( + parse_message_level_ex.get_full_name_from_any_step(_ALLSIMPLE), + "struct2tensor.test.AllSimple", + ) + self.assertEqual( + parse_message_level_ex.get_full_name_from_any_step(_USERINFO), + "struct2tensor.test.UserInfo", + ) + self.assertIsNone(parse_message_level_ex.get_full_name_from_any_step("broken")) + self.assertIsNone(parse_message_level_ex.get_full_name_from_any_step("(broken")) + self.assertIsNone(parse_message_level_ex.get_full_name_from_any_step("broken)")) + + @parameterized.named_parameters( + [dict(testcase_name=f, message_format=f) for f in _MESSAGE_FORMATS] + ) + def test_normal_field(self, message_format): + """Test three messages with a repeated string.""" + all_simple = test_pb2.AllSimple() + all_simple.repeated_string.append("foo") + all_simple.repeated_string.append("foo2") + 
all_simple_empty = test_pb2.AllSimple()
+
+        result = _run_parse_message_level_ex(
+            [all_simple, all_simple_empty, all_simple, all_simple],
+            {"repeated_string"},
+            message_format,
+        )
+        self.assertNotIn("repeated_bool", result)
+
+        self.assertAllEqual(result["repeated_string"][_INDEX], [0, 0, 2, 2, 3, 3])
+        self.assertAllEqual(
+            result["repeated_string"][_VALUE],
+            [b"foo", b"foo2", b"foo", b"foo2", b"foo", b"foo2"],
+        )
+
+    @parameterized.named_parameters(
+        [dict(testcase_name=f, message_format=f) for f in _MESSAGE_FORMATS]
+    )
+    def test_bool_key_type(self, message_format):
+        map_field = "bool_string_map[1]"
+        message_with_map_0 = test_map_pb2.MessageWithMap()
+        message_with_map_0.bool_string_map[False] = "hello"
+        message_with_map_1 = test_map_pb2.MessageWithMap()
+        message_with_map_1.bool_string_map[True] = "goodbye"
+        result = _run_parse_message_level_ex(
+            [message_with_map_0, message_with_map_1], {map_field}, message_format
+        )
+        self.assertIn(map_field, result)
+        self.assertAllEqual(result[map_field][_VALUE], [b"goodbye"])
+        self.assertAllEqual(result[map_field][_INDEX], [1])

 if __name__ == "__main__":
-  absltest.main()
+    absltest.main()
diff --git a/struct2tensor/expression_impl/placeholder.py b/struct2tensor/expression_impl/placeholder.py
index c6e3444..2760f43 100644
--- a/struct2tensor/expression_impl/placeholder.py
+++ b/struct2tensor/expression_impl/placeholder.py
@@ -39,167 +39,179 @@ from struct2tensor.expression_impl import map_prensor_to_prensor as mpp

-def create_expression_from_schema(
-    schema: mpp.Schema) -> "_PlaceholderRootExpression":
-  """Creates a placeholder expression from a parquet schema.
+def create_expression_from_schema(schema: mpp.Schema) -> "_PlaceholderRootExpression":
+    """Creates a placeholder expression from a schema.

-  Args:
-    schema: The schema that describes the prensor tree that this placeholder
-      represents.
+    Args:
+      schema: The schema that describes the prensor tree that this placeholder
+        represents.

-  Returns:
-    A PlaceholderRootExpression that should be used as the root of an expression
-    graph.
-  """
-  return _PlaceholderRootExpression(schema)
+    Returns:
+      A PlaceholderRootExpression that should be used as the root of an expression
+      graph.
+    """
+    return _PlaceholderRootExpression(schema)

 def _is_placeholder_expression(expr: expression.Expression) -> bool:
-  """Returns true if an expression is a ParquetExpression."""
-  return isinstance(expr,
-                    (_PlaceholderRootExpression, _PlaceholderChildExpression))
+    """Returns true if an expression is a placeholder expression."""
+    return isinstance(expr, (_PlaceholderRootExpression, _PlaceholderChildExpression))

 def get_placeholder_paths_from_graph(
-    graph: calculate.ExpressionGraph) -> List[path.Path]:
-  """Gets all placeholder paths from an expression graph.
+    graph: calculate.ExpressionGraph,
+) -> List[path.Path]:
+    """Gets all placeholder paths from an expression graph.

-  This finds all leaf placeholder expressions in an expression graph, and gets
-  the path of these expressions.
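To make the placeholder API concrete: a minimal usage sketch, assuming a hypothetical input prensor `input_prensor` that is bound at calculate time (the schema-building pattern follows placeholder_test.py later in this patch):

    import tensorflow as tf

    from struct2tensor import calculate, path
    from struct2tensor.expression_impl import map_prensor_to_prensor as mpp
    from struct2tensor.expression_impl import placeholder

    # Describe the tree shape the placeholder stands for: user.friends is a
    # repeated string leaf under a repeated root.
    root_schema = mpp.create_schema(
        is_repeated=True,
        children={
            "user": {
                "is_repeated": True,
                "children": {"friends": {"is_repeated": True, "dtype": tf.string}},
            }
        },
    )
    exp = placeholder.create_expression_from_schema(root_schema)
    friends = exp.get_descendant_or_error(path.Path(["user", "friends"]))
    # The placeholder is only bound to real data when values are calculated.
    [friends_node] = calculate.calculate_values(
        [friends], feed_dict={exp: input_prensor}
    )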
-
-  Args:
-    graph: expression graph
+    Args:
+      graph: expression graph

-  Returns:
-    a list of paths of placeholder expressions
-  """
-  expressions = [
-      x for x in graph.get_expressions_needed()
-      if (_is_placeholder_expression(x) and x.is_leaf)
-  ]
-  expressions = typing.cast(List[_PlaceholderExpression], expressions)
-  return [e.get_path() for e in expressions]
+    Returns:
+      a list of paths of placeholder expressions
+    """
+    expressions = [
+        x
+        for x in graph.get_expressions_needed()
+        if (_is_placeholder_expression(x) and x.is_leaf)
+    ]
+    expressions = typing.cast(List[_PlaceholderExpression], expressions)
+    return [e.get_path() for e in expressions]

 class _PlaceholderChildExpression(expression.Expression):
-  """A child or leaf parquet expression."""
-
-  # pylint: disable=protected-access
-  def __init__(self, parent: "_PlaceholderExpression", step: path.Step,
-               schema: mpp.Schema):
-    super().__init__(
-        schema.is_repeated, schema.type, schema_feature=schema.schema_feature)
-    self._parent = parent
-    self._step = step
-    self._schema = schema
-
-  @property
-  def schema(self):
-    return self._schema
-
-  @property
-  def is_leaf(self) -> bool:
-    return not self._schema._children
-
-  def get_path(self) -> path.Path:
-    return self._parent.get_path().get_child(self._step)
-
-  def get_source_expressions(self) -> Sequence[expression.Expression]:
-    return [self._parent]
-
-  def calculate(  # pytype: disable=signature-mismatch  # overriding-parameter-type-checks
-      self,
-      source_tensors: Sequence[mpp._TreeAsNode],
-      destinations: Sequence[expression.Expression],
-      options: calculate_options.Options,
-      side_info: Optional[prensor.Prensor] = None) -> mpp._TreeAsNode:
-    if side_info:
-      return mpp._tree_as_node(side_info)
-    [parent] = source_tensors
-    if not isinstance(
-        parent, (mpp._PrensorAsLeafNodeTensor, mpp._PrensorAsChildNodeTensor,
-                 mpp._PrensorAsRootNodeTensor)):
-      raise ValueError("calculate() of Parent did not return a "
-                       "_PrensorAsNodeTensor")
-    my_pren = parent.prensor.get_child(self._step)
-    if not my_pren:
-      raise ValueError("step " + self._step + " does not exist in prensor: " +
-                       str(parent.prensor))
-    return mpp._tree_as_node(my_pren)
-
-  def calculation_is_identity(self) -> bool:
-    return False
-
-  def calculation_equal(self, expr: expression.Expression) -> bool:
-    return self is expr
-
-  def _get_child_impl(self,
-                      field_name: path.Step) -> Optional[expression.Expression]:
-    if field_name not in self._schema.known_field_names():
-      return None
-    child_schema = self._schema.get_child(field_name)
-    return _PlaceholderChildExpression(self, field_name, child_schema)
-
-  def known_field_names(self) -> FrozenSet[path.Step]:
-    return self._schema.known_field_names()
+    """A child or leaf placeholder expression."""
+
+    # pylint: disable=protected-access
+    def __init__(
+        self, parent: "_PlaceholderExpression", step: path.Step, schema: mpp.Schema
+    ):
+        super().__init__(
+            schema.is_repeated, schema.type, schema_feature=schema.schema_feature
+        )
+        self._parent = parent
+        self._step = step
+        self._schema = schema
+
+    @property
+    def schema(self):
+        return self._schema
+
+    @property
+    def is_leaf(self) -> bool:
+        return not self._schema._children
+
+    def get_path(self) -> path.Path:
+        return self._parent.get_path().get_child(self._step)
+
+    def get_source_expressions(self) -> Sequence[expression.Expression]:
+        return [self._parent]
+
+    def calculate(  # pytype: disable=signature-mismatch  # overriding-parameter-type-checks
+        self,
+        source_tensors: Sequence[mpp._TreeAsNode],
+        destinations:
Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> mpp._TreeAsNode: + if side_info: + return mpp._tree_as_node(side_info) + [parent] = source_tensors + if not isinstance( + parent, + ( + mpp._PrensorAsLeafNodeTensor, + mpp._PrensorAsChildNodeTensor, + mpp._PrensorAsRootNodeTensor, + ), + ): + raise ValueError( + "calculate() of Parent did not return a _PrensorAsNodeTensor" + ) + my_pren = parent.prensor.get_child(self._step) + if not my_pren: + raise ValueError( + "step " + + self._step + + " does not exist in prensor: " + + str(parent.prensor) + ) + return mpp._tree_as_node(my_pren) + + def calculation_is_identity(self) -> bool: + return False + + def calculation_equal(self, expr: expression.Expression) -> bool: + return self is expr + + def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]: + if field_name not in self._schema.known_field_names(): + return None + child_schema = self._schema.get_child(field_name) + return _PlaceholderChildExpression(self, field_name, child_schema) + + def known_field_names(self) -> FrozenSet[path.Step]: + return self._schema.known_field_names() class _PlaceholderRootExpression(expression.Expression): - """An expression that calculates to the side_info passed in at calculate().""" + """An expression that calculates to the side_info passed in at calculate().""" - def __init__(self, schema: mpp.Schema): - """Initializes the root of the placeholder expression. + def __init__(self, schema: mpp.Schema): + """Initializes the root of the placeholder expression. - Args: - schema: the schema that represents what the expression tree looks like. - """ - super().__init__(True, None) - self._schema = schema + Args: + schema: the schema that represents what the expression tree looks like. 
+ """ + super().__init__(True, None) + self._schema = schema - @property - def schema(self): - return self._schema + @property + def schema(self): + return self._schema - @property - def is_leaf(self) -> bool: - return False + @property + def is_leaf(self) -> bool: + return False - def get_path(self) -> path.Path: - return path.Path([]) + def get_path(self) -> path.Path: + return path.Path([]) - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [] + def get_source_expressions(self) -> Sequence[expression.Expression]: + return [] - def calculate( # pytype: disable=signature-mismatch # overriding-parameter-type-checks - self, source_tensors: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: prensor.Prensor) -> mpp._TreeAsNode: - if source_tensors: - raise ValueError("_PlaceholderRootExpression has no sources") - if side_info: - return mpp._tree_as_node(side_info) # pylint: disable=protected-access - raise ValueError("_PlaceholderRootExpression requires side_info") + def calculate( # pytype: disable=signature-mismatch # overriding-parameter-type-checks + self, + source_tensors: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: prensor.Prensor, + ) -> mpp._TreeAsNode: + if source_tensors: + raise ValueError("_PlaceholderRootExpression has no sources") + if side_info: + return mpp._tree_as_node(side_info) # pylint: disable=protected-access + raise ValueError("_PlaceholderRootExpression requires side_info") - def calculation_is_identity(self) -> bool: - return False + def calculation_is_identity(self) -> bool: + return False - def calculation_equal(self, expr: expression.Expression) -> bool: - return self is expr + def calculation_equal(self, expr: expression.Expression) -> bool: + return self is expr - def _get_child_impl(self, - field_name: path.Step) -> Optional[expression.Expression]: - if field_name not in self._schema.known_field_names(): - return None - child_schema = self._schema.get_child(field_name) - return _PlaceholderChildExpression(self, field_name, child_schema) + def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]: + if field_name not in self._schema.known_field_names(): + return None + child_schema = self._schema.get_child(field_name) + return _PlaceholderChildExpression(self, field_name, child_schema) - def __str__(self) -> str: - return f"_PlaceholderRootExpression: {str(self._schema)}" + def __str__(self) -> str: + return f"_PlaceholderRootExpression: {str(self._schema)}" - def known_field_names(self) -> FrozenSet[path.Step]: - return self._schema.known_field_names() + def known_field_names(self) -> FrozenSet[path.Step]: + return self._schema.known_field_names() -_PlaceholderExpression = Union[_PlaceholderRootExpression, - _PlaceholderChildExpression] +_PlaceholderExpression = Union[_PlaceholderRootExpression, _PlaceholderChildExpression] diff --git a/struct2tensor/expression_impl/placeholder_test.py b/struct2tensor/expression_impl/placeholder_test.py index 957a4a4..52f28ce 100644 --- a/struct2tensor/expression_impl/placeholder_test.py +++ b/struct2tensor/expression_impl/placeholder_test.py @@ -16,7 +16,7 @@ import tensorflow as tf from absl.testing import absltest from tensorflow.python.framework import ( - test_util, # pylint: disable=g-direct-tensorflow-import + test_util, # pylint: disable=g-direct-tensorflow-import ) from struct2tensor import calculate, 
path, prensor @@ -27,80 +27,73 @@ @test_util.run_all_in_graph_and_eager_modes class PlaceholderTest(tf.test.TestCase): + def testPlaceholderExpression(self): + pren = prensor_test_util.create_nested_prensor() + expected_pren = prensor.create_prensor_from_descendant_nodes( + { + path.Path([]): prensor.RootNodeTensor(tf.constant(3, dtype=tf.int64)), + path.Path(["new_friends"]): prensor.LeafNodeTensor( + tf.constant([0, 1, 1, 1, 2], dtype=tf.int64), + tf.constant(["a", "b", "c", "d", "e"], dtype=tf.string), + True, + ), + } + ) - def testPlaceholderExpression(self): - pren = prensor_test_util.create_nested_prensor() - expected_pren = prensor.create_prensor_from_descendant_nodes({ - path.Path([]): - prensor.RootNodeTensor(tf.constant(3, dtype=tf.int64)), - path.Path(["new_friends"]): - prensor.LeafNodeTensor( - tf.constant([0, 1, 1, 1, 2], dtype=tf.int64), - tf.constant(["a", "b", "c", "d", "e"], dtype=tf.string), True) - }) - - root_schema = mpp.create_schema( - is_repeated=True, - children={ - "doc": { - "is_repeated": True, - "children": { - "bar": { - "is_repeated": True, - "dtype": tf.string + root_schema = mpp.create_schema( + is_repeated=True, + children={ + "doc": { + "is_repeated": True, + "children": { + "bar": {"is_repeated": True, "dtype": tf.string}, + "keep_me": {"is_repeated": False, "dtype": tf.bool}, }, - "keep_me": { - "is_repeated": False, - "dtype": tf.bool - } - } + }, + "user": { + "is_repeated": True, + "children": {"friends": {"is_repeated": True, "dtype": tf.string}}, + }, }, - "user": { - "is_repeated": True, - "children": { - "friends": { - "is_repeated": True, - "dtype": tf.string - } - } - } - }) + ) - exp = placeholder.create_expression_from_schema(root_schema) - promote_exp = promote.promote(exp, path.Path(["user", "friends"]), - "new_friends") - project_exp = project.project(promote_exp, [path.Path(["new_friends"])]) - new_friends_exp = project_exp.get_descendant(path.Path(["new_friends"])) + exp = placeholder.create_expression_from_schema(root_schema) + promote_exp = promote.promote( + exp, path.Path(["user", "friends"]), "new_friends" + ) + project_exp = project.project(promote_exp, [path.Path(["new_friends"])]) + new_friends_exp = project_exp.get_descendant(path.Path(["new_friends"])) - result = calculate.calculate_values([new_friends_exp], - feed_dict={exp: pren}) + result = calculate.calculate_values([new_friends_exp], feed_dict={exp: pren}) - res_node = result[0] - exp_node = expected_pren.get_descendant(path.Path(["new_friends"])).node + res_node = result[0] + exp_node = expected_pren.get_descendant(path.Path(["new_friends"])).node - self.assertAllEqual(res_node.is_repeated, exp_node.is_repeated) - self.assertAllEqual(res_node.values, exp_node.values) - self.assertAllEqual(res_node.parent_index, exp_node.parent_index) + self.assertAllEqual(res_node.is_repeated, exp_node.is_repeated) + self.assertAllEqual(res_node.values, exp_node.values) + self.assertAllEqual(res_node.parent_index, exp_node.parent_index) - def testCreateExpressionFromSchema(self): - root_schema = mpp.create_schema(is_repeated=True, children={}) - exp = placeholder.create_expression_from_schema(root_schema) - pren = prensor.create_prensor_from_descendant_nodes( - {path.Path([]): prensor.RootNodeTensor(tf.constant(1, dtype=tf.int64))}) - result = calculate.calculate_values([exp], feed_dict={exp: pren}) - res_node = result[0] - exp_node = pren.get_descendant(path.Path([])).node + def testCreateExpressionFromSchema(self): + root_schema = mpp.create_schema(is_repeated=True, children={}) + exp 
= placeholder.create_expression_from_schema(root_schema) + pren = prensor.create_prensor_from_descendant_nodes( + {path.Path([]): prensor.RootNodeTensor(tf.constant(1, dtype=tf.int64))} + ) + result = calculate.calculate_values([exp], feed_dict={exp: pren}) + res_node = result[0] + exp_node = pren.get_descendant(path.Path([])).node - self.assertAllEqual(res_node.is_repeated, exp_node.is_repeated) - self.assertAllEqual(res_node.size, exp_node.size) + self.assertAllEqual(res_node.is_repeated, exp_node.is_repeated) + self.assertAllEqual(res_node.size, exp_node.size) - def testPlaceholderRootExpressionRequiresSideInfo(self): - root_schema = mpp.create_schema(is_repeated=True, children={}) - exp = placeholder.create_expression_from_schema(root_schema) - with self.assertRaisesRegex( - ValueError, "_PlaceholderRootExpression requires side_info"): - calculate.calculate_values([exp], feed_dict={exp: None}) + def testPlaceholderRootExpressionRequiresSideInfo(self): + root_schema = mpp.create_schema(is_repeated=True, children={}) + exp = placeholder.create_expression_from_schema(root_schema) + with self.assertRaisesRegex( + ValueError, "_PlaceholderRootExpression requires side_info" + ): + calculate.calculate_values([exp], feed_dict={exp: None}) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_impl/project.py b/struct2tensor/expression_impl/project.py index 4c69b49..2f1996b 100644 --- a/struct2tensor/expression_impl/project.py +++ b/struct2tensor/expression_impl/project.py @@ -33,84 +33,88 @@ from struct2tensor import calculate_options, expression, path, prensor -def project(expr: expression.Expression, - paths: Sequence[path.Path]) -> expression.Expression: - """Select a subtree. - - Paths not selected are removed. - Paths that are selected are "known", such that if calculate_prensors is - called, they will be in the result. - - Args: - expr: the original expression. - paths: the paths to include. - - Returns: - A projected expression. - """ - missing_paths = [p for p in paths if expr.get_descendant(p) is None] - if missing_paths: - raise ValueError("{} Path(s) missing in project: {}".format( - len(missing_paths), ", ".join([str(x) for x in missing_paths]))) - return _ProjectExpression(expr, paths) +def project( + expr: expression.Expression, paths: Sequence[path.Path] +) -> expression.Expression: + """Select a subtree. + + Paths not selected are removed. + Paths that are selected are "known", such that if calculate_prensors is + called, they will be in the result. + + Args: + expr: the original expression. + paths: the paths to include. + + Returns: + A projected expression. 
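A sketch of the call this docstring describes, assuming an existing expression `expr` that has user.friends and doc.keep_me fields (the same pattern is exercised in project_test.py further down):

    from struct2tensor import path
    from struct2tensor.expression_impl import project

    projected = project.project(
        expr, [path.Path(["user", "friends"]), path.Path(["doc", "keep_me"])]
    )
    # Selected paths stay reachable; everything else is pruned.
    assert projected.get_descendant(path.Path(["user", "friends"])) is not None
    assert projected.get_descendant(path.Path(["doc", "bar"])) is None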
+ """ + missing_paths = [p for p in paths if expr.get_descendant(p) is None] + if missing_paths: + raise ValueError( + "{} Path(s) missing in project: {}".format( + len(missing_paths), ", ".join([str(x) for x in missing_paths]) + ) + ) + return _ProjectExpression(expr, paths) def _group_paths_by_first_step( - paths: Sequence[path.Path]) -> Mapping[path.Step, List[path.Path]]: - result = collections.defaultdict(list) - for p in paths: - if p: - first_step = p.field_list[0] - result[first_step].append(p.suffix(1)) - return result + paths: Sequence[path.Path], +) -> Mapping[path.Step, List[path.Path]]: + result = collections.defaultdict(list) + for p in paths: + if p: + first_step = p.field_list[0] + result[first_step].append(p.suffix(1)) + return result class _ProjectExpression(expression.Expression): - """Project all subfields of an expression.""" - - def __init__( - self, - origin: expression.Expression, - paths: Sequence[path.Path], - ): - super().__init__( - origin.is_repeated, - origin.type, - origin.schema_feature, - validate_step_format=origin.validate_step_format, - ) - self._paths_map = _group_paths_by_first_step(paths) - self._origin = origin - - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [self._origin] - - def calculate( - self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: - if len(sources) != 1: - raise ValueError("Expected one source.") - return sources[0] - - def calculation_is_identity(self) -> bool: - return True - - def calculation_equal(self, expr: expression.Expression) -> bool: - return expr.calculation_is_identity() - - def _get_child_impl(self, - field_name: path.Step) -> Optional[expression.Expression]: - paths = self._paths_map.get(field_name) - if paths is None: - return None - original = self._origin.get_child(field_name) - if original is None: - raise ValueError( - f"Project a field that doesn't exist: {field_name}") - return _ProjectExpression(original, paths) - - def known_field_names(self) -> FrozenSet[path.Step]: - return frozenset(self._paths_map.keys()) + """Project all subfields of an expression.""" + + def __init__( + self, + origin: expression.Expression, + paths: Sequence[path.Path], + ): + super().__init__( + origin.is_repeated, + origin.type, + origin.schema_feature, + validate_step_format=origin.validate_step_format, + ) + self._paths_map = _group_paths_by_first_step(paths) + self._origin = origin + + def get_source_expressions(self) -> Sequence[expression.Expression]: + return [self._origin] + + def calculate( + self, + sources: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> prensor.NodeTensor: + if len(sources) != 1: + raise ValueError("Expected one source.") + return sources[0] + + def calculation_is_identity(self) -> bool: + return True + + def calculation_equal(self, expr: expression.Expression) -> bool: + return expr.calculation_is_identity() + + def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]: + paths = self._paths_map.get(field_name) + if paths is None: + return None + original = self._origin.get_child(field_name) + if original is None: + raise ValueError(f"Project a field that doesn't exist: {field_name}") + return _ProjectExpression(original, paths) + + def known_field_names(self) -> 
FrozenSet[path.Step]: + return frozenset(self._paths_map.keys()) diff --git a/struct2tensor/expression_impl/project_test.py b/struct2tensor/expression_impl/project_test.py index e8ac610..f6b4865 100644 --- a/struct2tensor/expression_impl/project_test.py +++ b/struct2tensor/expression_impl/project_test.py @@ -21,29 +21,27 @@ class ProjectTest(absltest.TestCase): - - def test_project(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor()) - projected = project.project( - expr, [path.Path(["user", "friends"]), - path.Path(["doc", "keep_me"])]) - self.assertIsNotNone( - projected.get_descendant(path.Path(["user", "friends"]))) - self.assertIsNotNone( - projected.get_descendant(path.Path(["doc", "keep_me"]))) - self.assertIsNone(projected.get_descendant(path.Path(["doc", "bar"]))) - - def test_project_lenient_format(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor_with_lenient_field_names(), - validate_step_format=False, - ) - projected = project.project( - expr, [path.Path(["doc", "keep_me/x"], validate_step_format=False)] - ) - self.assertLen(projected.get_known_descendants(), 3) + def test_project(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor() + ) + projected = project.project( + expr, [path.Path(["user", "friends"]), path.Path(["doc", "keep_me"])] + ) + self.assertIsNotNone(projected.get_descendant(path.Path(["user", "friends"]))) + self.assertIsNotNone(projected.get_descendant(path.Path(["doc", "keep_me"]))) + self.assertIsNone(projected.get_descendant(path.Path(["doc", "bar"]))) + + def test_project_lenient_format(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor_with_lenient_field_names(), + validate_step_format=False, + ) + projected = project.project( + expr, [path.Path(["doc", "keep_me/x"], validate_step_format=False)] + ) + self.assertLen(projected.get_known_descendants(), 3) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_impl/promote.py b/struct2tensor/expression_impl/promote.py index b2a1f92..667aa3d 100644 --- a/struct2tensor/expression_impl/promote.py +++ b/struct2tensor/expression_impl/promote.py @@ -105,250 +105,268 @@ class PromoteExpression(expression.Leaf): - """A promoted leaf.""" - - def __init__(self, origin: expression.Expression, - origin_parent: expression.Expression): - - super().__init__( - origin.is_repeated or origin_parent.is_repeated, - origin.type, - schema_feature=_get_promote_schema_feature( - origin.schema_feature, origin_parent.schema_feature)) - self._origin = origin - self._origin_parent = origin_parent - if self.type is None: - raise ValueError("Can only promote a field") - if self._origin_parent.type is not None: - raise ValueError("origin_parent cannot be a field") - - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [self._origin, self._origin_parent] - - def calculate( - self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: - [origin_value, origin_parent_value] = sources - if not isinstance(origin_value, prensor.LeafNodeTensor): - raise ValueError("origin_value must be a leaf") - if not isinstance(origin_parent_value, prensor.ChildNodeTensor): - raise ValueError("origin_parent_value must be a 
child node") - new_parent_index = tf.gather(origin_parent_value.parent_index, - origin_value.parent_index) - return prensor.LeafNodeTensor(new_parent_index, origin_value.values, - self.is_repeated) - - def calculation_is_identity(self) -> bool: - return False - - def calculation_equal(self, expr: expression.Expression) -> bool: - return isinstance(expr, PromoteExpression) + """A promoted leaf.""" + + def __init__( + self, origin: expression.Expression, origin_parent: expression.Expression + ): + super().__init__( + origin.is_repeated or origin_parent.is_repeated, + origin.type, + schema_feature=_get_promote_schema_feature( + origin.schema_feature, origin_parent.schema_feature + ), + ) + self._origin = origin + self._origin_parent = origin_parent + if self.type is None: + raise ValueError("Can only promote a field") + if self._origin_parent.type is not None: + raise ValueError("origin_parent cannot be a field") + + def get_source_expressions(self) -> Sequence[expression.Expression]: + return [self._origin, self._origin_parent] + + def calculate( + self, + sources: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> prensor.NodeTensor: + [origin_value, origin_parent_value] = sources + if not isinstance(origin_value, prensor.LeafNodeTensor): + raise ValueError("origin_value must be a leaf") + if not isinstance(origin_parent_value, prensor.ChildNodeTensor): + raise ValueError("origin_parent_value must be a child node") + new_parent_index = tf.gather( + origin_parent_value.parent_index, origin_value.parent_index + ) + return prensor.LeafNodeTensor( + new_parent_index, origin_value.values, self.is_repeated + ) + + def calculation_is_identity(self) -> bool: + return False + + def calculation_equal(self, expr: expression.Expression) -> bool: + return isinstance(expr, PromoteExpression) class PromoteChildExpression(expression.Expression): - """The root of the promoted sub tree.""" - - def __init__(self, origin: expression.Expression, - origin_parent: expression.Expression): - - super().__init__( - origin.is_repeated or origin_parent.is_repeated, - origin.type, - schema_feature=_get_promote_schema_feature( - origin.schema_feature, origin_parent.schema_feature - ), - validate_step_format=origin.validate_step_format, - ) - self._origin = origin - self._origin_parent = origin_parent - if self._origin_parent.type is not None: - raise ValueError("origin_parent cannot be a field") - - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [self._origin, self._origin_parent] - - def calculate( - self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: - [origin_value, origin_parent_value] = sources - if not isinstance(origin_value, prensor.ChildNodeTensor): - raise ValueError("origin_value must be a child") - if not isinstance(origin_parent_value, prensor.ChildNodeTensor): - raise ValueError("origin_parent_value must be a child node") - new_parent_index = tf.gather(origin_parent_value.parent_index, - origin_value.parent_index) - return prensor.ChildNodeTensor(new_parent_index, self.is_repeated) - - def calculation_is_identity(self) -> bool: - return False - - def calculation_equal(self, expr: expression.Expression) -> bool: - return isinstance(expr, PromoteChildExpression) - - def _get_child_impl(self, - field_name: 
path.Step) -> Optional[expression.Expression]: - return self._origin.get_child(field_name) - - def known_field_names(self) -> FrozenSet[path.Step]: - return self._origin.known_field_names() + """The root of the promoted sub tree.""" + + def __init__( + self, origin: expression.Expression, origin_parent: expression.Expression + ): + super().__init__( + origin.is_repeated or origin_parent.is_repeated, + origin.type, + schema_feature=_get_promote_schema_feature( + origin.schema_feature, origin_parent.schema_feature + ), + validate_step_format=origin.validate_step_format, + ) + self._origin = origin + self._origin_parent = origin_parent + if self._origin_parent.type is not None: + raise ValueError("origin_parent cannot be a field") + + def get_source_expressions(self) -> Sequence[expression.Expression]: + return [self._origin, self._origin_parent] + + def calculate( + self, + sources: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> prensor.NodeTensor: + [origin_value, origin_parent_value] = sources + if not isinstance(origin_value, prensor.ChildNodeTensor): + raise ValueError("origin_value must be a child") + if not isinstance(origin_parent_value, prensor.ChildNodeTensor): + raise ValueError("origin_parent_value must be a child node") + new_parent_index = tf.gather( + origin_parent_value.parent_index, origin_value.parent_index + ) + return prensor.ChildNodeTensor(new_parent_index, self.is_repeated) + + def calculation_is_identity(self) -> bool: + return False + + def calculation_equal(self, expr: expression.Expression) -> bool: + return isinstance(expr, PromoteChildExpression) + + def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]: + return self._origin.get_child(field_name) + + def known_field_names(self) -> FrozenSet[path.Step]: + return self._origin.known_field_names() def _lifecycle_stage_number(a) -> int: - """Return a number indicating the quality of the lifecycle stage. - - When there is more than one input field, the minimum lifecycle stage could be - used. - - Args: - a: an Optional[LifecycleStage] - - Returns: - An integer that corresponds to the lifecycle stage of 'a'. - """ - stages = [ - schema_pb2.LifecycleStage.DEPRECATED, schema_pb2.LifecycleStage.DISABLED, - schema_pb2.LifecycleStage.PLANNED, schema_pb2.LifecycleStage.ALPHA, - schema_pb2.LifecycleStage.DEBUG_ONLY, None, - schema_pb2.LifecycleStage.UNKNOWN_STAGE, schema_pb2.LifecycleStage.BETA, - schema_pb2.LifecycleStage.PRODUCTION - ] - return stages.index(a) + """Return a number indicating the quality of the lifecycle stage. + + When there is more than one input field, the minimum lifecycle stage could be + used. + + Args: + a: an Optional[LifecycleStage] + + Returns: + An integer that corresponds to the lifecycle stage of 'a'. + """ + stages = [ + schema_pb2.LifecycleStage.DEPRECATED, + schema_pb2.LifecycleStage.DISABLED, + schema_pb2.LifecycleStage.PLANNED, + schema_pb2.LifecycleStage.ALPHA, + schema_pb2.LifecycleStage.DEBUG_ONLY, + None, + schema_pb2.LifecycleStage.UNKNOWN_STAGE, + schema_pb2.LifecycleStage.BETA, + schema_pb2.LifecycleStage.PRODUCTION, + ] + return stages.index(a) def _min_lifecycle_stage(a, b): - """Get the minimum lifecycle stage. + """Get the minimum lifecycle stage. 
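Reading the stages list above: the earlier a stage appears, the "worse" it is, and the worse stage wins when two features are combined. A small illustrative check of that ordering (the module-private helpers are used directly here purely for illustration):

    from tensorflow_metadata.proto.v0 import schema_pb2

    # DEPRECATED has the lowest index, so it dominates any other stage.
    assert _lifecycle_stage_number(schema_pb2.LifecycleStage.DEPRECATED) == 0
    # None sorts between DEBUG_ONLY and UNKNOWN_STAGE.
    assert (_lifecycle_stage_number(schema_pb2.LifecycleStage.DEBUG_ONLY)
            < _lifecycle_stage_number(None)
            < _lifecycle_stage_number(schema_pb2.LifecycleStage.UNKNOWN_STAGE))
    # PLANNED is worse than ALPHA, so it is the minimum of the two.
    assert _min_lifecycle_stage(
        schema_pb2.LifecycleStage.PLANNED, schema_pb2.LifecycleStage.ALPHA
    ) == schema_pb2.LifecycleStage.PLANNED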
- Args: - a: an Optional[LifecycleStage] - b: an Optional[LifecycleStage] + Args: + a: an Optional[LifecycleStage] + b: an Optional[LifecycleStage] - Returns: - the minimal lifecycle stage. - """ - if _lifecycle_stage_number(b) < _lifecycle_stage_number(a): - return b - return a + Returns: + the minimal lifecycle stage. + """ + if _lifecycle_stage_number(b) < _lifecycle_stage_number(a): + return b + return a def _feature_is_dense(feature: schema_pb2.Feature) -> bool: - return (feature.presence.min_fraction == 1.0 and - feature.value_count.HasField("min") and - feature.value_count.HasField("max") and - feature.value_count.min == feature.value_count.max) + return ( + feature.presence.min_fraction == 1.0 + and feature.value_count.HasField("min") + and feature.value_count.HasField("max") + and feature.value_count.min == feature.value_count.max + ) def _copy_domain_info(origin: schema_pb2.Feature, dest: schema_pb2.Feature): - """Copy the domain info.""" - one_of_field_name = origin.WhichOneof("domain_info") - if one_of_field_name is None: - return - - origin_field = getattr(origin, one_of_field_name) - - field_descriptor = origin.DESCRIPTOR.fields_by_name.get(one_of_field_name) - if field_descriptor is None or field_descriptor.message_type is None: - setattr(dest, one_of_field_name, origin_field) - else: - dest_field = getattr(dest, one_of_field_name) - dest_field.CopyFrom(origin_field) - - -def _get_promote_schema_feature(original: Optional[schema_pb2.Feature], - parent: Optional[schema_pb2.Feature] - ) -> Optional[schema_pb2.Feature]: - """Generate the schema feature for the field resulting from promote. - - Note that promote results in the exact same number of values. - - Note that min_count is never propagated. - - Args: - original: the original feature - parent: the parent feature - - Returns: - the schema of the new field. - """ - if original is None or parent is None: - return None - result = schema_pb2.Feature() - result.lifecycle_stage = _min_lifecycle_stage(original.lifecycle_stage, - parent.lifecycle_stage) - result.type = original.type - if original.HasField("distribution_constraints"): - result.distribution_constraints.CopyFrom(original.distribution_constraints) - _copy_domain_info(original, result) - - if _feature_is_dense(parent): - parent_size = parent.value_count.min - if original.value_count.HasField("min"): - result.value_count.min = parent_size * original.value_count.min - if original.value_count.HasField("max"): - result.value_count.max = parent_size * original.value_count.max - if original.presence.HasField("min_fraction"): - if original.presence.min_fraction == 1: - result.presence.min_fraction = 1 - else: - result.presence.min_fraction = ( - original.presence.min_fraction / parent_size) - if original.presence.HasField("min_count"): - # If the parent is dense then the count can - # be reduced by the number of children. - # E.g. {{"a"},{"b"}},{{"c"},{"d"}},{{"e"},{"f"}} - # with a count of 6, with a parent size of 2 becomes: - # can become {"a","b"}, {"c", "d"}, {"e", "f"} - # which has a count of 3. - result.presence.min_count = original.presence.min_count // parent_size - return result - - -def _promote_impl(root: expression.Expression, p: path.Path, - new_field_name: path.Step - ) -> Tuple[expression.Expression, path.Path]: - """Promotes a path to be a child of its grandparent, and gives it a name. - - Args: - root: The root expression. - p: The path to promote. This can be the path to a leaf or child node. - new_field_name: The name of the promoted field. 
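The public wrappers for this helper are promote and promote_anonymous, defined at the end of this file; a minimal sketch of the named form, assuming an existing expression `expr` with a user.friends leaf (promote_test.py below makes the same call):

    from struct2tensor import path
    from struct2tensor.expression_impl import promote

    # user.friends becomes a sibling of user named new_field; its values are
    # re-indexed to the grandparent via the tf.gather shown above.
    new_root = promote.promote(expr, path.Path(["user", "friends"]), "new_field")
    new_field = new_root.get_child_or_error("new_field")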
- - Returns: - An _AddPathsExpression that wraps a PromoteExpression. - """ - if len(p) < 2: - raise ValueError(f"Cannot do a promotion beyond the root: {str(p)}") - parent_path = p.get_parent() - grandparent_path = parent_path.get_parent() - - p_expression = root.get_descendant_or_error(p) - new_path = grandparent_path.get_child(new_field_name) - - if p_expression.is_leaf: - promote_expression_factory = PromoteExpression - else: - promote_expression_factory = PromoteChildExpression - - return expression_add.add_paths( - root, { - new_path: - promote_expression_factory( - p_expression, root.get_descendant_or_error(parent_path)) - }), new_path - - -def promote_anonymous(root: expression.Expression, - p: path.Path) -> Tuple[expression.Expression, path.Path]: - """Promote a path to be a new anonymous child of its grandparent.""" - return _promote_impl(root, p, path.get_anonymous_field()) - - -def promote(root: expression.Expression, p: path.Path, - new_field_name: path.Step) -> expression.Expression: - """Promote a path to be a child of its grandparent, and give it a name.""" - return _promote_impl(root, p, new_field_name)[0] + """Copy the domain info.""" + one_of_field_name = origin.WhichOneof("domain_info") + if one_of_field_name is None: + return + + origin_field = getattr(origin, one_of_field_name) + + field_descriptor = origin.DESCRIPTOR.fields_by_name.get(one_of_field_name) + if field_descriptor is None or field_descriptor.message_type is None: + setattr(dest, one_of_field_name, origin_field) + else: + dest_field = getattr(dest, one_of_field_name) + dest_field.CopyFrom(origin_field) + + +def _get_promote_schema_feature( + original: Optional[schema_pb2.Feature], parent: Optional[schema_pb2.Feature] +) -> Optional[schema_pb2.Feature]: + """Generate the schema feature for the field resulting from promote. + + Note that promote results in the exact same number of values. + + Note that min_count is never propagated. + + Args: + original: the original feature + parent: the parent feature + + Returns: + the schema of the new field. + """ + if original is None or parent is None: + return None + result = schema_pb2.Feature() + result.lifecycle_stage = _min_lifecycle_stage( + original.lifecycle_stage, parent.lifecycle_stage + ) + result.type = original.type + if original.HasField("distribution_constraints"): + result.distribution_constraints.CopyFrom(original.distribution_constraints) + _copy_domain_info(original, result) + + if _feature_is_dense(parent): + parent_size = parent.value_count.min + if original.value_count.HasField("min"): + result.value_count.min = parent_size * original.value_count.min + if original.value_count.HasField("max"): + result.value_count.max = parent_size * original.value_count.max + if original.presence.HasField("min_fraction"): + if original.presence.min_fraction == 1: + result.presence.min_fraction = 1 + else: + result.presence.min_fraction = ( + original.presence.min_fraction / parent_size + ) + if original.presence.HasField("min_count"): + # If the parent is dense then the count can + # be reduced by the number of children. + # E.g. {{"a"},{"b"}},{{"c"},{"d"}},{{"e"},{"f"}} + # with a count of 6, with a parent size of 2 becomes: + # can become {"a","b"}, {"c", "d"}, {"e", "f"} + # which has a count of 3. 
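# Worked example, using the numbers from test_promote_with_schema_dense_parent
# below: a dense parent with value_count.min == value_count.max == 3 and a
# child with value_count 2, presence.min_fraction 1 and presence.min_count 10
# yields a promoted feature with value_count.min == value_count.max == 3 * 2
# == 6, presence.min_fraction 1 and presence.min_count 10 // 3 == 3.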
+ result.presence.min_count = original.presence.min_count // parent_size + return result + + +def _promote_impl( + root: expression.Expression, p: path.Path, new_field_name: path.Step +) -> Tuple[expression.Expression, path.Path]: + """Promotes a path to be a child of its grandparent, and gives it a name. + + Args: + root: The root expression. + p: The path to promote. This can be the path to a leaf or child node. + new_field_name: The name of the promoted field. + + Returns: + An _AddPathsExpression that wraps a PromoteExpression. + """ + if len(p) < 2: + raise ValueError(f"Cannot do a promotion beyond the root: {str(p)}") + parent_path = p.get_parent() + grandparent_path = parent_path.get_parent() + + p_expression = root.get_descendant_or_error(p) + new_path = grandparent_path.get_child(new_field_name) + + if p_expression.is_leaf: + promote_expression_factory = PromoteExpression + else: + promote_expression_factory = PromoteChildExpression + + return expression_add.add_paths( + root, + { + new_path: promote_expression_factory( + p_expression, root.get_descendant_or_error(parent_path) + ) + }, + ), new_path + + +def promote_anonymous( + root: expression.Expression, p: path.Path +) -> Tuple[expression.Expression, path.Path]: + """Promote a path to be a new anonymous child of its grandparent.""" + return _promote_impl(root, p, path.get_anonymous_field()) + + +def promote( + root: expression.Expression, p: path.Path, new_field_name: path.Step +) -> expression.Expression: + """Promote a path to be a child of its grandparent, and give it a name.""" + return _promote_impl(root, p, new_field_name)[0] diff --git a/struct2tensor/expression_impl/promote_and_broadcast.py b/struct2tensor/expression_impl/promote_and_broadcast.py index 238aff3..20b8650 100644 --- a/struct2tensor/expression_impl/promote_and_broadcast.py +++ b/struct2tensor/expression_impl/promote_and_broadcast.py @@ -117,55 +117,64 @@ def promote_and_broadcast_anonymous( - root: expression.Expression, origin: path.Path, - new_parent: path.Path) -> Tuple[expression.Expression, path.Path]: - """Promotes then broadcasts the origin until its parent is new_parent.""" - least_common_ancestor = origin.get_least_common_ancestor(new_parent) - - new_expr, new_path = root, origin - while new_path.get_parent() != least_common_ancestor: - new_expr, new_path = promote.promote_anonymous(new_expr, new_path) - - while new_path.get_parent() != new_parent: - new_parent_step = new_parent.field_list[len(new_path) - 1] - new_expr, new_path = broadcast.broadcast_anonymous(new_expr, new_path, - new_parent_step) - - return new_expr, new_path - - -def _promote_and_broadcast_name(root: expression.Expression, origin: path.Path, - dest_path_parent: path.Path, - field_name: path.Step) -> expression.Expression: - new_root, anonymous_path = promote_and_broadcast_anonymous( - root, origin, dest_path_parent) - path_result = dest_path_parent.get_child(field_name) - return expression_add.add_paths( - new_root, {path_result: new_root.get_descendant_or_error(anonymous_path)}) - - -def promote_and_broadcast(root: expression.Expression, - path_dictionary: Mapping[path.Step, path.Path], - dest_path_parent: path.Path) -> expression.Expression: - """Promote and broadcast a set of paths to a particular location. - - Args: - root: the original expression. - path_dictionary: a map from destination fields to origin paths. - dest_path_parent: a map from destination strings. 
-
-  Returns:
-    A new expression, where all the origin paths are promoted and broadcast
-    until they are children of dest_path_parent.
-  """
-  result_paths = {}
-  # Here, we branch out and create a different tree for each field that is
-  # promoted and broadcast.
-  for field_name, origin_path in path_dictionary.items():
-    result_path = dest_path_parent.get_child(field_name)
-    new_root = _promote_and_broadcast_name(root, origin_path, dest_path_parent,
-                                           field_name)
-    result_paths[result_path] = new_root
-  # We create a new tree that has all of the generated fields from the older
-  # trees.
-  return expression_add.add_to(root, result_paths)
+    root: expression.Expression, origin: path.Path, new_parent: path.Path
+) -> Tuple[expression.Expression, path.Path]:
+    """Promotes then broadcasts the origin until its parent is new_parent."""
+    least_common_ancestor = origin.get_least_common_ancestor(new_parent)
+
+    new_expr, new_path = root, origin
+    while new_path.get_parent() != least_common_ancestor:
+        new_expr, new_path = promote.promote_anonymous(new_expr, new_path)
+
+    while new_path.get_parent() != new_parent:
+        new_parent_step = new_parent.field_list[len(new_path) - 1]
+        new_expr, new_path = broadcast.broadcast_anonymous(
+            new_expr, new_path, new_parent_step
+        )
+
+    return new_expr, new_path
+
+
+def _promote_and_broadcast_name(
+    root: expression.Expression,
+    origin: path.Path,
+    dest_path_parent: path.Path,
+    field_name: path.Step,
+) -> expression.Expression:
+    new_root, anonymous_path = promote_and_broadcast_anonymous(
+        root, origin, dest_path_parent
+    )
+    path_result = dest_path_parent.get_child(field_name)
+    return expression_add.add_paths(
+        new_root, {path_result: new_root.get_descendant_or_error(anonymous_path)}
+    )
+
+
+def promote_and_broadcast(
+    root: expression.Expression,
+    path_dictionary: Mapping[path.Step, path.Path],
+    dest_path_parent: path.Path,
+) -> expression.Expression:
+    """Promote and broadcast a set of paths to a particular location.
+
+    Args:
+      root: the original expression.
+      path_dictionary: a map from destination fields to origin paths.
+      dest_path_parent: the path that will be the parent of the new fields.
+
+    Returns:
+      A new expression, where all the origin paths are promoted and broadcast
+      until they are children of dest_path_parent.
+    """
+    result_paths = {}
+    # Here, we branch out and create a different tree for each field that is
+    # promoted and broadcast.
+    for field_name, origin_path in path_dictionary.items():
+        result_path = dest_path_parent.get_child(field_name)
+        new_root = _promote_and_broadcast_name(
+            root, origin_path, dest_path_parent, field_name
+        )
+        result_paths[result_path] = new_root
+    # We create a new tree that has all of the generated fields from the older
+    # trees.
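For the anonymous variant reformatted above, a usage sketch assuming an existing expression `expr` with user.friends and doc (promote_and_broadcast_test.py in the next hunk makes exactly this call):

    from struct2tensor import path
    from struct2tensor.expression_impl import promote_and_broadcast

    # Promote user.friends, then broadcast it into doc; `p` is the path of the
    # new anonymous field under doc.
    new_root, p = promote_and_broadcast.promote_and_broadcast_anonymous(
        expr, path.Path(["user", "friends"]), path.Path(["doc"])
    )
    new_field = new_root.get_descendant_or_error(p)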
+ return expression_add.add_to(root, result_paths) diff --git a/struct2tensor/expression_impl/promote_and_broadcast_test.py b/struct2tensor/expression_impl/promote_and_broadcast_test.py index eafa6ce..d7f30e9 100644 --- a/struct2tensor/expression_impl/promote_and_broadcast_test.py +++ b/struct2tensor/expression_impl/promote_and_broadcast_test.py @@ -22,23 +22,24 @@ class PromoteAndBroadcastTest(absltest.TestCase): - - def test_promote_and_broadcast_anonymous(self): - """A basic promote and broadcast.""" - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - new_root, p = promote_and_broadcast.promote_and_broadcast_anonymous( - expr, path.Path(["user", "friends"]), path.Path(["doc"])) - new_field = new_root.get_descendant_or_error(p) - self.assertTrue(new_field.is_repeated) - self.assertEqual(new_field.type, tf.string) - self.assertTrue(new_field.is_leaf) - self.assertTrue(new_field.calculation_equal(new_field)) - self.assertFalse(new_field.calculation_equal(expr)) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertEqual(leaf_node.values.dtype, tf.string) - self.assertEqual(new_field.known_field_names(), frozenset()) + def test_promote_and_broadcast_anonymous(self): + """A basic promote and broadcast.""" + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + new_root, p = promote_and_broadcast.promote_and_broadcast_anonymous( + expr, path.Path(["user", "friends"]), path.Path(["doc"]) + ) + new_field = new_root.get_descendant_or_error(p) + self.assertTrue(new_field.is_repeated) + self.assertEqual(new_field.type, tf.string) + self.assertTrue(new_field.is_leaf) + self.assertTrue(new_field.calculation_equal(new_field)) + self.assertFalse(new_field.calculation_equal(expr)) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertEqual(leaf_node.values.dtype, tf.string) + self.assertEqual(new_field.known_field_names(), frozenset()) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_impl/promote_test.py b/struct2tensor/expression_impl/promote_test.py index 71490ae..554c808 100644 --- a/struct2tensor/expression_impl/promote_test.py +++ b/struct2tensor/expression_impl/promote_test.py @@ -16,7 +16,7 @@ import tensorflow as tf from absl.testing import absltest from tensorflow.python.framework import ( - test_util, # pylint: disable=g-direct-tensorflow-import + test_util, # pylint: disable=g-direct-tensorflow-import ) from tensorflow_metadata.proto.v0 import schema_pb2 @@ -26,514 +26,580 @@ class PromoteTest(absltest.TestCase): - - def assertLen(self, arr, expected_len): - self.assertEqual(len(arr), expected_len) # pylint:disable=g-generic-assert - - def test_promote_anonymous(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor()) - new_root, new_field = promote.promote_anonymous( - expr, path.Path(["user", "friends"])) - new_field = new_root.get_descendant_or_error(new_field) - self.assertTrue(new_field.is_repeated) - self.assertEqual(new_field.type, tf.string) - self.assertTrue(new_field.is_leaf) - self.assertFalse(new_field.calculation_is_identity()) - self.assertTrue(new_field.calculation_equal(new_field)) - self.assertFalse(new_field.calculation_equal(expr)) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertEqual(leaf_node.values.dtype, tf.string) - self.assertEqual(new_field.known_field_names(), frozenset()) 
- - sources = new_field.get_source_expressions() - self.assertLen(sources, 2) - self.assertIs( - expr.get_descendant_or_error(path.Path(["user", "friends"])), - sources[0]) - self.assertIs(expr.get_child_or_error("user"), sources[1]) - - def test_promote_with_schema(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()).apply_schema( - prensor_test_util.create_big_prensor_schema()) - - new_root, new_field = promote.promote_anonymous( - expr, path.Path(["user", "friends"])) - new_field = new_root.get_descendant_or_error(new_field) - new_schema_feature = new_field.schema_feature - self.assertIsNotNone(new_schema_feature) - self.assertEqual(new_schema_feature.string_domain.value[0], "a") - - def test_promote_with_schema_dense_parent(self): - s = prensor_test_util.create_big_prensor_schema() - feature_dict = {feature.name: feature for feature in s.feature} - user_feature = feature_dict["user"] - user_feature.value_count.min = 3 - user_feature.value_count.max = 3 - user_feature.presence.min_fraction = 1 - user_feature.lifecycle_stage = schema_pb2.LifecycleStage.ALPHA - - user_dict = { - feature.name: feature for feature in user_feature.struct_domain.feature - } - friends_feature = user_dict["friends"] - friends_feature.value_count.min = 2 - friends_feature.value_count.max = 2 - friends_feature.presence.min_fraction = 1 - friends_feature.presence.min_count = 10 - friends_feature.lifecycle_stage = schema_pb2.LifecycleStage.BETA - friends_feature.distribution_constraints.min_domain_mass = 0.5 - - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()).apply_schema(s) - - new_root, new_field = promote.promote_anonymous( - expr, path.Path(["user", "friends"])) - new_field = new_root.get_descendant_or_error(new_field) - new_schema_feature = new_field.schema_feature - self.assertIsNotNone(new_schema_feature) - self.assertEqual(new_schema_feature.string_domain.value[0], "a") - self.assertEqual(new_schema_feature.value_count.max, 6) - self.assertEqual(new_schema_feature.value_count.min, 6) - self.assertEqual(new_schema_feature.presence.min_fraction, 1) - self.assertEqual(new_schema_feature.presence.min_count, 3) - self.assertEqual(new_schema_feature.lifecycle_stage, - schema_pb2.LifecycleStage.ALPHA) - self.assertEqual( - new_schema_feature.distribution_constraints.min_domain_mass, 0.5) - - def test_lifecycle_stage(self): - # Stages have the following priority, from lowest to highest: - # schema_pb2.LifecycleStage.DEPRECATED - # schema_pb2.LifecycleStage.DISABLED - # schema_pb2.LifecycleStage.PLANNED, - # schema_pb2.LifecycleStage.ALPHA - # schema_pb2.LifecycleStage.DEBUG_ONLY, - # None - # schema_pb2.LifecycleStage.UNKNOWN_STAGE, - # schema_pb2.LifecycleStage.BETA - # schema_pb2.LifecycleStage.PRODUCTION - def _check_lifecycle_stage(a, b): - s = prensor_test_util.create_big_prensor_schema() - feature_dict = {feature.name: feature for feature in s.feature} - user_feature = feature_dict["user"] - if a is not None: - user_feature.lifecycle_stage = a - - user_dict = { - feature.name: feature - for feature in user_feature.struct_domain.feature - } - friends_feature = user_dict["friends"] - if b is not None: - friends_feature.lifecycle_stage = b - - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()).apply_schema(s) - - new_root, new_field = promote.promote_anonymous( - expr, path.Path(["user", "friends"])) - new_field = new_root.get_descendant_or_error(new_field) - 
return new_field.schema_feature.lifecycle_stage - - self.assertEqual( - schema_pb2.LifecycleStage.DEPRECATED, - _check_lifecycle_stage(schema_pb2.LifecycleStage.DEPRECATED, - schema_pb2.LifecycleStage.DISABLED)) - self.assertEqual( - schema_pb2.LifecycleStage.DEPRECATED, - _check_lifecycle_stage(schema_pb2.LifecycleStage.DISABLED, - schema_pb2.LifecycleStage.DEPRECATED)) - - self.assertEqual( - schema_pb2.LifecycleStage.DISABLED, - _check_lifecycle_stage(schema_pb2.LifecycleStage.PLANNED, - schema_pb2.LifecycleStage.DISABLED)) - - self.assertEqual( - schema_pb2.LifecycleStage.DEPRECATED, - _check_lifecycle_stage(schema_pb2.LifecycleStage.PLANNED, - schema_pb2.LifecycleStage.DEPRECATED)) - - self.assertEqual( - schema_pb2.LifecycleStage.PLANNED, - _check_lifecycle_stage(schema_pb2.LifecycleStage.PLANNED, - schema_pb2.LifecycleStage.ALPHA)) - self.assertEqual( - schema_pb2.LifecycleStage.PLANNED, - _check_lifecycle_stage(schema_pb2.LifecycleStage.ALPHA, - schema_pb2.LifecycleStage.PLANNED)) - - self.assertEqual( - schema_pb2.LifecycleStage.ALPHA, - _check_lifecycle_stage(schema_pb2.LifecycleStage.DEBUG_ONLY, - schema_pb2.LifecycleStage.ALPHA)) - self.assertEqual( - schema_pb2.LifecycleStage.ALPHA, - _check_lifecycle_stage(schema_pb2.LifecycleStage.ALPHA, - schema_pb2.LifecycleStage.DEBUG_ONLY)) - - self.assertEqual( - schema_pb2.LifecycleStage.DEBUG_ONLY, - _check_lifecycle_stage(schema_pb2.LifecycleStage.DEBUG_ONLY, None)) - - self.assertEqual( - schema_pb2.LifecycleStage.DEBUG_ONLY, - _check_lifecycle_stage(None, schema_pb2.LifecycleStage.DEBUG_ONLY)) - - # None looks like UNKNOWN_STAGE. - self.assertEqual( - schema_pb2.LifecycleStage.UNKNOWN_STAGE, - _check_lifecycle_stage(None, schema_pb2.LifecycleStage.UNKNOWN_STAGE)) - self.assertEqual( - schema_pb2.LifecycleStage.UNKNOWN_STAGE, - _check_lifecycle_stage(schema_pb2.LifecycleStage.UNKNOWN_STAGE, None)) - - self.assertEqual( - schema_pb2.LifecycleStage.UNKNOWN_STAGE, - _check_lifecycle_stage(schema_pb2.LifecycleStage.BETA, - schema_pb2.LifecycleStage.UNKNOWN_STAGE)) - self.assertEqual( - schema_pb2.LifecycleStage.UNKNOWN_STAGE, - _check_lifecycle_stage(schema_pb2.LifecycleStage.UNKNOWN_STAGE, - schema_pb2.LifecycleStage.BETA)) - - self.assertEqual( - schema_pb2.LifecycleStage.BETA, - _check_lifecycle_stage(schema_pb2.LifecycleStage.BETA, - schema_pb2.LifecycleStage.PRODUCTION)) - self.assertEqual( - schema_pb2.LifecycleStage.BETA, - _check_lifecycle_stage(schema_pb2.LifecycleStage.PRODUCTION, - schema_pb2.LifecycleStage.BETA)) - - def test_promote_with_schema_dense_fraction(self): - """Test when min_fraction is not 1.""" - s = prensor_test_util.create_big_prensor_schema() - feature_dict = {feature.name: feature for feature in s.feature} - user_feature = feature_dict["user"] - user_feature.value_count.min = 3 - user_feature.value_count.max = 3 - user_feature.presence.min_fraction = 1 - - user_dict = { - feature.name: feature for feature in user_feature.struct_domain.feature - } - friends_feature = user_dict["friends"] - friends_feature.presence.min_fraction = 0.9 - - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()).apply_schema(s) - - new_root, new_field = promote.promote_anonymous( - expr, path.Path(["user", "friends"])) - new_field = new_root.get_descendant_or_error(new_field) - new_schema_feature = new_field.schema_feature - self.assertIsNotNone(new_schema_feature) - self.assertEqual(new_schema_feature.presence.min_fraction, 0.3) - - def test_promote_optional_child_of_repeated(self): - 
expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor()) - new_root, new_field = promote.promote_anonymous( - expr, path.Path(["doc", "keep_me"])) - new_expr = new_root.get_descendant_or_error(new_field) - self.assertTrue(new_expr.is_repeated) - - def test_promote(self): - """Tests promote.promote(...), and indirectly tests set_path.""" - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor()) - new_root = promote.promote(expr, path.Path(["user", "friends"]), - "new_field") - new_field = new_root.get_child_or_error("new_field") - self.assertIsNotNone(new_field) - self.assertTrue(new_field.is_repeated) - self.assertEqual(new_field.type, tf.string) - self.assertTrue(new_field.is_leaf) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertEqual(leaf_node.values.dtype, tf.string) - self.assertEqual(new_field.known_field_names(), frozenset()) - - def test_promote_substructure(self): - """Tests promote.promote(...) of substructure.""" - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_deep_prensor()) - new_root = promote.promote(expr, path.Path(["event", "doc"]), "new_field") - - new_field = new_root.get_child_or_error("new_field") - self.assertIsNotNone(new_field) - self.assertTrue(new_field.is_repeated) - self.assertEqual(new_field.known_field_names(), - frozenset(["bar", "keep_me"])) - - bar_expr = new_field.get_child_or_error("bar") - self.assertIsNotNone(bar_expr) - self.assertTrue(bar_expr.is_repeated) - self.assertEqual(bar_expr.type, tf.string) - self.assertTrue(bar_expr.is_leaf) - - keep_me_expr = new_field.get_child_or_error("keep_me") - self.assertIsNotNone(keep_me_expr) - self.assertFalse(keep_me_expr.is_repeated) - self.assertEqual(keep_me_expr.type, tf.bool) - self.assertTrue(keep_me_expr.is_leaf) - - child_node = expression_test_util.calculate_value_slowly(new_field) - self.assertEqual(child_node.size, 3) - self.assertTrue(child_node.is_repeated) - - bar_node = expression_test_util.calculate_value_slowly(bar_expr) - self.assertEqual(bar_node.values.dtype, tf.string) - - keep_me_node = expression_test_util.calculate_value_slowly(keep_me_expr) - self.assertEqual(keep_me_node.values.dtype, tf.bool) - - def test_promote_substructure_then_leaf(self): - """Tests expr.promote(...) 
of substructure and then a leaf.""" - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_deep_prensor()) - new_root = (expr - .promote(path.Path(["event", "doc"]), "new_field") - .promote(path.Path(["new_field", "bar"]), "new_bar")) - - new_bar = new_root.get_child_or_error("new_bar") - self.assertIsNotNone(new_bar) - self.assertTrue(new_bar.is_repeated) - self.assertEqual(new_bar.type, tf.string) - self.assertTrue(new_bar.is_leaf) - - new_field_bar = new_root.get_descendant_or_error( - path.Path(["new_field", "bar"])) - self.assertIsNotNone(new_field_bar) - self.assertTrue(new_bar.is_repeated) - self.assertEqual(new_bar.type, tf.string) - self.assertTrue(new_bar.is_leaf) - - new_field_keep_me = new_root.get_descendant_or_error( - path.Path(["new_field", "keep_me"])) - self.assertIsNotNone(new_field_keep_me) - self.assertFalse(new_field_keep_me.is_repeated) - self.assertEqual(new_field_keep_me.type, tf.bool) - self.assertTrue(new_field_keep_me.is_leaf) - - bar_node = expression_test_util.calculate_value_slowly(new_bar) - self.assertEqual(bar_node.values.dtype, tf.string) - - new_field_bar_node = expression_test_util.calculate_value_slowly( - new_field_bar) - self.assertEqual(new_field_bar_node.values.dtype, tf.string) - - new_field_keep_me_node = expression_test_util.calculate_value_slowly( - new_field_keep_me) - self.assertEqual(new_field_keep_me_node.values.dtype, tf.bool) - - def test_promote_leaf_then_substructure(self): - """Tests expr.promote(...) of leaf and then a substructure.""" - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_four_layer_prensor()) - new_root = ( - expr - .promote(path.Path(["event", "doc", "nested_child", "bar"]), "new_bar") - .promote(path.Path(["event", "doc"]), "new_doc")) - - new_doc = new_root.get_child_or_error("new_doc") - self.assertIsNotNone(new_doc) - self.assertTrue(new_doc.is_repeated) - self.assertEqual(new_doc.known_field_names(), - frozenset(["nested_child", "new_bar"])) - - new_bar_expr = new_doc.get_child_or_error("new_bar") - self.assertIsNotNone(new_bar_expr) - self.assertTrue(new_bar_expr.is_repeated) - self.assertEqual(new_bar_expr.type, tf.string) - self.assertTrue(new_bar_expr.is_leaf) - - nested_child_expr = new_doc.get_child_or_error("nested_child") - self.assertIsNotNone(nested_child_expr) - self.assertTrue(nested_child_expr.is_repeated) - self.assertEqual(nested_child_expr.known_field_names(), - frozenset(["bar", "keep_me"])) - - bar_expr = nested_child_expr.get_child_or_error("bar") - self.assertIsNotNone(bar_expr) - self.assertTrue(bar_expr.is_repeated) - self.assertEqual(bar_expr.type, tf.string) - self.assertTrue(bar_expr.is_leaf) - - keep_me_expr = nested_child_expr.get_child_or_error("keep_me") - self.assertIsNotNone(keep_me_expr) - self.assertFalse(keep_me_expr.is_repeated) - self.assertEqual(keep_me_expr.type, tf.bool) - self.assertTrue(keep_me_expr.is_leaf) - - bar_node = expression_test_util.calculate_value_slowly(new_bar_expr) - self.assertEqual(bar_node.values.dtype, tf.string) + def assertLen(self, arr, expected_len): + self.assertEqual(len(arr), expected_len) # pylint:disable=g-generic-assert + + def test_promote_anonymous(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor() + ) + new_root, new_field = promote.promote_anonymous( + expr, path.Path(["user", "friends"]) + ) + new_field = new_root.get_descendant_or_error(new_field) + self.assertTrue(new_field.is_repeated) + 
self.assertEqual(new_field.type, tf.string) + self.assertTrue(new_field.is_leaf) + self.assertFalse(new_field.calculation_is_identity()) + self.assertTrue(new_field.calculation_equal(new_field)) + self.assertFalse(new_field.calculation_equal(expr)) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertEqual(leaf_node.values.dtype, tf.string) + self.assertEqual(new_field.known_field_names(), frozenset()) + + sources = new_field.get_source_expressions() + self.assertLen(sources, 2) + self.assertIs( + expr.get_descendant_or_error(path.Path(["user", "friends"])), sources[0] + ) + self.assertIs(expr.get_child_or_error("user"), sources[1]) + + def test_promote_with_schema(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ).apply_schema(prensor_test_util.create_big_prensor_schema()) + + new_root, new_field = promote.promote_anonymous( + expr, path.Path(["user", "friends"]) + ) + new_field = new_root.get_descendant_or_error(new_field) + new_schema_feature = new_field.schema_feature + self.assertIsNotNone(new_schema_feature) + self.assertEqual(new_schema_feature.string_domain.value[0], "a") + + def test_promote_with_schema_dense_parent(self): + s = prensor_test_util.create_big_prensor_schema() + feature_dict = {feature.name: feature for feature in s.feature} + user_feature = feature_dict["user"] + user_feature.value_count.min = 3 + user_feature.value_count.max = 3 + user_feature.presence.min_fraction = 1 + user_feature.lifecycle_stage = schema_pb2.LifecycleStage.ALPHA + + user_dict = { + feature.name: feature for feature in user_feature.struct_domain.feature + } + friends_feature = user_dict["friends"] + friends_feature.value_count.min = 2 + friends_feature.value_count.max = 2 + friends_feature.presence.min_fraction = 1 + friends_feature.presence.min_count = 10 + friends_feature.lifecycle_stage = schema_pb2.LifecycleStage.BETA + friends_feature.distribution_constraints.min_domain_mass = 0.5 + + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ).apply_schema(s) + + new_root, new_field = promote.promote_anonymous( + expr, path.Path(["user", "friends"]) + ) + new_field = new_root.get_descendant_or_error(new_field) + new_schema_feature = new_field.schema_feature + self.assertIsNotNone(new_schema_feature) + self.assertEqual(new_schema_feature.string_domain.value[0], "a") + self.assertEqual(new_schema_feature.value_count.max, 6) + self.assertEqual(new_schema_feature.value_count.min, 6) + self.assertEqual(new_schema_feature.presence.min_fraction, 1) + self.assertEqual(new_schema_feature.presence.min_count, 3) + self.assertEqual( + new_schema_feature.lifecycle_stage, schema_pb2.LifecycleStage.ALPHA + ) + self.assertEqual( + new_schema_feature.distribution_constraints.min_domain_mass, 0.5 + ) + + def test_lifecycle_stage(self): + # Stages have the following priority, from lowest to highest: + # schema_pb2.LifecycleStage.DEPRECATED + # schema_pb2.LifecycleStage.DISABLED + # schema_pb2.LifecycleStage.PLANNED, + # schema_pb2.LifecycleStage.ALPHA + # schema_pb2.LifecycleStage.DEBUG_ONLY, + # None + # schema_pb2.LifecycleStage.UNKNOWN_STAGE, + # schema_pb2.LifecycleStage.BETA + # schema_pb2.LifecycleStage.PRODUCTION + def _check_lifecycle_stage(a, b): + s = prensor_test_util.create_big_prensor_schema() + feature_dict = {feature.name: feature for feature in s.feature} + user_feature = feature_dict["user"] + if a is not None: + user_feature.lifecycle_stage = a + + user_dict = { 
+ feature.name: feature for feature in user_feature.struct_domain.feature + } + friends_feature = user_dict["friends"] + if b is not None: + friends_feature.lifecycle_stage = b + + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ).apply_schema(s) + + new_root, new_field = promote.promote_anonymous( + expr, path.Path(["user", "friends"]) + ) + new_field = new_root.get_descendant_or_error(new_field) + return new_field.schema_feature.lifecycle_stage + + self.assertEqual( + schema_pb2.LifecycleStage.DEPRECATED, + _check_lifecycle_stage( + schema_pb2.LifecycleStage.DEPRECATED, schema_pb2.LifecycleStage.DISABLED + ), + ) + self.assertEqual( + schema_pb2.LifecycleStage.DEPRECATED, + _check_lifecycle_stage( + schema_pb2.LifecycleStage.DISABLED, schema_pb2.LifecycleStage.DEPRECATED + ), + ) + + self.assertEqual( + schema_pb2.LifecycleStage.DISABLED, + _check_lifecycle_stage( + schema_pb2.LifecycleStage.PLANNED, schema_pb2.LifecycleStage.DISABLED + ), + ) + + self.assertEqual( + schema_pb2.LifecycleStage.DEPRECATED, + _check_lifecycle_stage( + schema_pb2.LifecycleStage.PLANNED, schema_pb2.LifecycleStage.DEPRECATED + ), + ) + + self.assertEqual( + schema_pb2.LifecycleStage.PLANNED, + _check_lifecycle_stage( + schema_pb2.LifecycleStage.PLANNED, schema_pb2.LifecycleStage.ALPHA + ), + ) + self.assertEqual( + schema_pb2.LifecycleStage.PLANNED, + _check_lifecycle_stage( + schema_pb2.LifecycleStage.ALPHA, schema_pb2.LifecycleStage.PLANNED + ), + ) + + self.assertEqual( + schema_pb2.LifecycleStage.ALPHA, + _check_lifecycle_stage( + schema_pb2.LifecycleStage.DEBUG_ONLY, schema_pb2.LifecycleStage.ALPHA + ), + ) + self.assertEqual( + schema_pb2.LifecycleStage.ALPHA, + _check_lifecycle_stage( + schema_pb2.LifecycleStage.ALPHA, schema_pb2.LifecycleStage.DEBUG_ONLY + ), + ) + + self.assertEqual( + schema_pb2.LifecycleStage.DEBUG_ONLY, + _check_lifecycle_stage(schema_pb2.LifecycleStage.DEBUG_ONLY, None), + ) + + self.assertEqual( + schema_pb2.LifecycleStage.DEBUG_ONLY, + _check_lifecycle_stage(None, schema_pb2.LifecycleStage.DEBUG_ONLY), + ) + + # None looks like UNKNOWN_STAGE. 
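The assertions above and below pin down a total order on lifecycle stages. A minimal sketch of the merge rule they imply (the helper and table names here are illustrative, not this library's implementation): promotion keeps whichever stage has the lower priority, with None coerced to UNKNOWN_STAGE.

from tensorflow_metadata.proto.v0 import schema_pb2

_STAGE_PRIORITY = [  # lowest priority first
    schema_pb2.LifecycleStage.DEPRECATED,
    schema_pb2.LifecycleStage.DISABLED,
    schema_pb2.LifecycleStage.PLANNED,
    schema_pb2.LifecycleStage.ALPHA,
    schema_pb2.LifecycleStage.DEBUG_ONLY,
    schema_pb2.LifecycleStage.UNKNOWN_STAGE,  # None is treated like this stage.
    schema_pb2.LifecycleStage.BETA,
    schema_pb2.LifecycleStage.PRODUCTION,
]


def _merge_lifecycle_stage(a, b):  # hypothetical helper name
    # Coerce None to UNKNOWN_STAGE, then keep the lower-priority stage.
    stages = [
        schema_pb2.LifecycleStage.UNKNOWN_STAGE if s is None else s for s in (a, b)
    ]
    return min(stages, key=_STAGE_PRIORITY.index)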
+ self.assertEqual( + schema_pb2.LifecycleStage.UNKNOWN_STAGE, + _check_lifecycle_stage(None, schema_pb2.LifecycleStage.UNKNOWN_STAGE), + ) + self.assertEqual( + schema_pb2.LifecycleStage.UNKNOWN_STAGE, + _check_lifecycle_stage(schema_pb2.LifecycleStage.UNKNOWN_STAGE, None), + ) + + self.assertEqual( + schema_pb2.LifecycleStage.UNKNOWN_STAGE, + _check_lifecycle_stage( + schema_pb2.LifecycleStage.BETA, schema_pb2.LifecycleStage.UNKNOWN_STAGE + ), + ) + self.assertEqual( + schema_pb2.LifecycleStage.UNKNOWN_STAGE, + _check_lifecycle_stage( + schema_pb2.LifecycleStage.UNKNOWN_STAGE, schema_pb2.LifecycleStage.BETA + ), + ) + + self.assertEqual( + schema_pb2.LifecycleStage.BETA, + _check_lifecycle_stage( + schema_pb2.LifecycleStage.BETA, schema_pb2.LifecycleStage.PRODUCTION + ), + ) + self.assertEqual( + schema_pb2.LifecycleStage.BETA, + _check_lifecycle_stage( + schema_pb2.LifecycleStage.PRODUCTION, schema_pb2.LifecycleStage.BETA + ), + ) + + def test_promote_with_schema_dense_fraction(self): + """Test when min_fraction is not 1.""" + s = prensor_test_util.create_big_prensor_schema() + feature_dict = {feature.name: feature for feature in s.feature} + user_feature = feature_dict["user"] + user_feature.value_count.min = 3 + user_feature.value_count.max = 3 + user_feature.presence.min_fraction = 1 + + user_dict = { + feature.name: feature for feature in user_feature.struct_domain.feature + } + friends_feature = user_dict["friends"] + friends_feature.presence.min_fraction = 0.9 + + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ).apply_schema(s) + + new_root, new_field = promote.promote_anonymous( + expr, path.Path(["user", "friends"]) + ) + new_field = new_root.get_descendant_or_error(new_field) + new_schema_feature = new_field.schema_feature + self.assertIsNotNone(new_schema_feature) + self.assertEqual(new_schema_feature.presence.min_fraction, 0.3) + + def test_promote_optional_child_of_repeated(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor() + ) + new_root, new_field = promote.promote_anonymous( + expr, path.Path(["doc", "keep_me"]) + ) + new_expr = new_root.get_descendant_or_error(new_field) + self.assertTrue(new_expr.is_repeated) + + def test_promote(self): + """Tests promote.promote(...), and indirectly tests set_path.""" + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor() + ) + new_root = promote.promote(expr, path.Path(["user", "friends"]), "new_field") + new_field = new_root.get_child_or_error("new_field") + self.assertIsNotNone(new_field) + self.assertTrue(new_field.is_repeated) + self.assertEqual(new_field.type, tf.string) + self.assertTrue(new_field.is_leaf) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertEqual(leaf_node.values.dtype, tf.string) + self.assertEqual(new_field.known_field_names(), frozenset()) + + def test_promote_substructure(self): + """Tests promote.promote(...) 
of substructure.""" + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_deep_prensor() + ) + new_root = promote.promote(expr, path.Path(["event", "doc"]), "new_field") + + new_field = new_root.get_child_or_error("new_field") + self.assertIsNotNone(new_field) + self.assertTrue(new_field.is_repeated) + self.assertEqual(new_field.known_field_names(), frozenset(["bar", "keep_me"])) + + bar_expr = new_field.get_child_or_error("bar") + self.assertIsNotNone(bar_expr) + self.assertTrue(bar_expr.is_repeated) + self.assertEqual(bar_expr.type, tf.string) + self.assertTrue(bar_expr.is_leaf) + + keep_me_expr = new_field.get_child_or_error("keep_me") + self.assertIsNotNone(keep_me_expr) + self.assertFalse(keep_me_expr.is_repeated) + self.assertEqual(keep_me_expr.type, tf.bool) + self.assertTrue(keep_me_expr.is_leaf) + + child_node = expression_test_util.calculate_value_slowly(new_field) + self.assertEqual(child_node.size, 3) + self.assertTrue(child_node.is_repeated) + + bar_node = expression_test_util.calculate_value_slowly(bar_expr) + self.assertEqual(bar_node.values.dtype, tf.string) + + keep_me_node = expression_test_util.calculate_value_slowly(keep_me_expr) + self.assertEqual(keep_me_node.values.dtype, tf.bool) + + def test_promote_substructure_then_leaf(self): + """Tests expr.promote(...) of substructure and then a leaf.""" + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_deep_prensor() + ) + new_root = expr.promote(path.Path(["event", "doc"]), "new_field").promote( + path.Path(["new_field", "bar"]), "new_bar" + ) + + new_bar = new_root.get_child_or_error("new_bar") + self.assertIsNotNone(new_bar) + self.assertTrue(new_bar.is_repeated) + self.assertEqual(new_bar.type, tf.string) + self.assertTrue(new_bar.is_leaf) + + new_field_bar = new_root.get_descendant_or_error( + path.Path(["new_field", "bar"]) + ) + self.assertIsNotNone(new_field_bar) + self.assertTrue(new_bar.is_repeated) + self.assertEqual(new_bar.type, tf.string) + self.assertTrue(new_bar.is_leaf) + + new_field_keep_me = new_root.get_descendant_or_error( + path.Path(["new_field", "keep_me"]) + ) + self.assertIsNotNone(new_field_keep_me) + self.assertFalse(new_field_keep_me.is_repeated) + self.assertEqual(new_field_keep_me.type, tf.bool) + self.assertTrue(new_field_keep_me.is_leaf) + + bar_node = expression_test_util.calculate_value_slowly(new_bar) + self.assertEqual(bar_node.values.dtype, tf.string) + + new_field_bar_node = expression_test_util.calculate_value_slowly(new_field_bar) + self.assertEqual(new_field_bar_node.values.dtype, tf.string) + + new_field_keep_me_node = expression_test_util.calculate_value_slowly( + new_field_keep_me + ) + self.assertEqual(new_field_keep_me_node.values.dtype, tf.bool) + + def test_promote_leaf_then_substructure(self): + """Tests expr.promote(...) 
of leaf and then a substructure.""" + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_four_layer_prensor() + ) + new_root = expr.promote( + path.Path(["event", "doc", "nested_child", "bar"]), "new_bar" + ).promote(path.Path(["event", "doc"]), "new_doc") + + new_doc = new_root.get_child_or_error("new_doc") + self.assertIsNotNone(new_doc) + self.assertTrue(new_doc.is_repeated) + self.assertEqual( + new_doc.known_field_names(), frozenset(["nested_child", "new_bar"]) + ) + + new_bar_expr = new_doc.get_child_or_error("new_bar") + self.assertIsNotNone(new_bar_expr) + self.assertTrue(new_bar_expr.is_repeated) + self.assertEqual(new_bar_expr.type, tf.string) + self.assertTrue(new_bar_expr.is_leaf) + + nested_child_expr = new_doc.get_child_or_error("nested_child") + self.assertIsNotNone(nested_child_expr) + self.assertTrue(nested_child_expr.is_repeated) + self.assertEqual( + nested_child_expr.known_field_names(), frozenset(["bar", "keep_me"]) + ) + + bar_expr = nested_child_expr.get_child_or_error("bar") + self.assertIsNotNone(bar_expr) + self.assertTrue(bar_expr.is_repeated) + self.assertEqual(bar_expr.type, tf.string) + self.assertTrue(bar_expr.is_leaf) + + keep_me_expr = nested_child_expr.get_child_or_error("keep_me") + self.assertIsNotNone(keep_me_expr) + self.assertFalse(keep_me_expr.is_repeated) + self.assertEqual(keep_me_expr.type, tf.bool) + self.assertTrue(keep_me_expr.is_leaf) + + bar_node = expression_test_util.calculate_value_slowly(new_bar_expr) + self.assertEqual(bar_node.values.dtype, tf.string) @test_util.run_all_in_graph_and_eager_modes class PromoteValuesTest(tf.test.TestCase): - - def test_promote_and_calculate(self): - """Tests promoting a leaf on a nested tree.""" - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor()) - new_root, new_path = promote.promote_anonymous( - expr, path.Path(["user", "friends"])) - new_field = new_root.get_descendant_or_error(new_path) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 1, 2]) - self.assertAllEqual(leaf_node.values, [b"a", b"b", b"c", b"d", b"e"]) - - def test_promote_and_calculate_substructure(self): - """Tests promoting substructure on a tree with depth of 4.""" - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_four_layer_prensor()) - new_root, new_path = promote.promote_anonymous( - expr, path.Path(["event", "doc", "nested_child"])) - new_nested_child = new_root.get_descendant_or_error(new_path) - bar_expr = new_root.get_descendant_or_error(new_path.get_child("bar")) - keep_me_expr = new_root.get_descendant_or_error( - new_path.get_child("keep_me")) - - # the promoted nested_child's parent index is changed. - nested_child_node = expression_test_util.calculate_value_slowly( - new_nested_child) - self.assertAllEqual(nested_child_node.parent_index, [0, 1, 1, 1]) - self.assertTrue(nested_child_node.is_repeated) - - # bar's parent index should be unchanged. - bar_node = expression_test_util.calculate_value_slowly(bar_expr) - self.assertAllEqual(bar_node.parent_index, [0, 1, 1, 2]) - self.assertAllEqual(bar_node.values, [b"a", b"b", b"c", b"d"]) - self.assertTrue(bar_node.is_repeated) - - # keep_me's parent index should be unchanged. 
- keep_me_node = expression_test_util.calculate_value_slowly(keep_me_expr) - self.assertAllEqual(keep_me_node.parent_index, [0, 1]) - self.assertAllEqual(keep_me_node.values, [False, True]) - self.assertFalse(keep_me_node.is_repeated) - - def test_promote_and_calculate_substructure_then_leaf(self): - """Tests promoting of substructure and then a leaf.""" - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_four_layer_prensor()) - new_root, new_nested_child_path = promote.promote_anonymous( - expr, path.Path(["event", "doc", "nested_child"])) - new_root, new_bar_path = promote.promote_anonymous( - new_root, new_nested_child_path.get_child("bar")) - - # the promoted nested_child's parent index is changed. - new_nested_child = new_root.get_descendant_or_error(new_nested_child_path) - nested_child_node = expression_test_util.calculate_value_slowly( - new_nested_child) - self.assertAllEqual(nested_child_node.parent_index, [0, 1, 1, 1]) - self.assertTrue(nested_child_node.is_repeated) - - # promoted bar's parent index is changed. - new_bar = new_root.get_descendant_or_error(new_bar_path) - bar_node = expression_test_util.calculate_value_slowly(new_bar) - self.assertAllEqual(bar_node.parent_index, [0, 1, 1, 1]) - self.assertAllEqual(bar_node.values, [b"a", b"b", b"c", b"d"]) - self.assertTrue(bar_node.is_repeated) - - # bar's parent index should be unchanged. - nested_child_bar = new_root.get_descendant_or_error( - new_nested_child_path.get_child("bar")) - nested_child_bar_node = expression_test_util.calculate_value_slowly( - nested_child_bar) - self.assertAllEqual(nested_child_bar_node.parent_index, [0, 1, 1, 2]) - self.assertAllEqual(nested_child_bar_node.values, [b"a", b"b", b"c", b"d"]) - self.assertTrue(nested_child_bar_node.is_repeated) - - # keep_me's parent index should be unchanged. - nested_child_keep_me = new_root.get_descendant_or_error( - new_nested_child_path.get_child("keep_me")) - nested_child_keep_me_node = expression_test_util.calculate_value_slowly( - nested_child_keep_me) - self.assertAllEqual(nested_child_keep_me_node.parent_index, [0, 1]) - self.assertAllEqual(nested_child_keep_me_node.values, [False, True]) - self.assertFalse(nested_child_keep_me_node.is_repeated) - - def test_promote_and_calculate_leaf_then_substructure(self): - """Tests promoting of leaf and then a substructure.""" - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_four_layer_prensor()) - new_root, new_bar_path = promote.promote_anonymous( - expr, path.Path(["event", "doc", "nested_child", "bar"])) - new_root, new_path = promote.promote_anonymous(new_root, - path.Path(["event", "doc"])) - - new_doc = new_root.get_descendant_or_error(new_path) - new_bar = new_root.get_descendant_or_error( - new_path.concat(new_bar_path.suffix(2))) - bar_expr = new_root.get_descendant_or_error( - new_path.concat(path.Path(["nested_child", "bar"]))) - keep_me_expr = new_root.get_descendant_or_error( - new_path.concat(path.Path(["nested_child", "keep_me"]))) - - new_doc_node = expression_test_util.calculate_value_slowly(new_doc) - self.assertAllEqual(new_doc_node.parent_index, [0, 1, 1]) - self.assertTrue(new_doc_node.is_repeated) - - # new_bar's parent index is changed (from the first promote). - # The second promote should not change new_bar's parent index. 
- new_bar_node = expression_test_util.calculate_value_slowly(new_bar) - self.assertAllEqual(new_bar_node.parent_index, [0, 1, 1, 1]) - self.assertAllEqual(new_bar_node.values, [b"a", b"b", b"c", b"d"]) - self.assertTrue(new_bar_node.is_repeated) - - # bar's parent index should be unchanged. - bar_node = expression_test_util.calculate_value_slowly(bar_expr) - self.assertAllEqual(bar_node.parent_index, [0, 1, 1, 2]) - self.assertAllEqual(bar_node.values, [b"a", b"b", b"c", b"d"]) - self.assertTrue(bar_node.is_repeated) - - # keep_me's parent index should be unchanged. - keep_me_node = expression_test_util.calculate_value_slowly(keep_me_expr) - self.assertAllEqual(keep_me_node.parent_index, [0, 1]) - self.assertAllEqual(keep_me_node.values, [False, True]) - self.assertFalse(keep_me_node.is_repeated) - - def test_promote_substructure_with_schema(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_deep_prensor()).apply_schema( - prensor_test_util.create_deep_prensor_schema()) - - original_schema = expr.get_descendant_or_error(path.Path(["event", "doc" - ])).schema_feature - - new_root, new_field_path = promote.promote_anonymous( - expr, path.Path(["event", "doc"])) - new_field = new_root.get_descendant_or_error(new_field_path) - new_schema_feature = new_field.schema_feature - self.assertIsNotNone(new_schema_feature) - - # The struct_domain of this feature should not be changed. - self.assertProtoEquals(new_schema_feature.struct_domain, - original_schema.struct_domain) - - bar_schema = new_root.get_descendant_or_error( - new_field_path.concat(path.Path(["bar"]))).schema_feature - self.assertIsNotNone(bar_schema) - self.assertEqual(bar_schema.string_domain.value[0], "a") - - keep_me_schema = new_root.get_descendant_or_error( - new_field_path.concat(path.Path(["keep_me"]))).schema_feature - self.assertIsNotNone(keep_me_schema) - self.assertEqual(keep_me_schema.presence.min_count, 1) + def test_promote_and_calculate(self): + """Tests promoting a leaf on a nested tree.""" + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor() + ) + new_root, new_path = promote.promote_anonymous( + expr, path.Path(["user", "friends"]) + ) + new_field = new_root.get_descendant_or_error(new_path) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 1, 2]) + self.assertAllEqual(leaf_node.values, [b"a", b"b", b"c", b"d", b"e"]) + + def test_promote_and_calculate_substructure(self): + """Tests promoting substructure on a tree with depth of 4.""" + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_four_layer_prensor() + ) + new_root, new_path = promote.promote_anonymous( + expr, path.Path(["event", "doc", "nested_child"]) + ) + new_nested_child = new_root.get_descendant_or_error(new_path) + bar_expr = new_root.get_descendant_or_error(new_path.get_child("bar")) + keep_me_expr = new_root.get_descendant_or_error(new_path.get_child("keep_me")) + + # the promoted nested_child's parent index is changed. + nested_child_node = expression_test_util.calculate_value_slowly( + new_nested_child + ) + self.assertAllEqual(nested_child_node.parent_index, [0, 1, 1, 1]) + self.assertTrue(nested_child_node.is_repeated) + + # bar's parent index should be unchanged. 
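The two behaviors checked in this test follow from one re-indexing rule: each row of a promoted node is re-pointed at its old parent's parent, while rows of the node's own children are untouched. A sketch, not the library's code, assuming fixture values doc.parent_index == [0, 1, 1] and nested_child.parent_index == [0, 1, 1, 2] (the fixture itself is not shown in this hunk):

import tensorflow as tf


def _promoted_parent_index(parent_parent_index, child_parent_index):
    # Each promoted row's new parent is its old parent's parent.
    return tf.gather(parent_parent_index, child_parent_index)

# tf.gather([0, 1, 1], [0, 1, 1, 2]) == [0, 1, 1, 1], the parent_index
# asserted for the promoted nested_child above.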
+ bar_node = expression_test_util.calculate_value_slowly(bar_expr) + self.assertAllEqual(bar_node.parent_index, [0, 1, 1, 2]) + self.assertAllEqual(bar_node.values, [b"a", b"b", b"c", b"d"]) + self.assertTrue(bar_node.is_repeated) + + # keep_me's parent index should be unchanged. + keep_me_node = expression_test_util.calculate_value_slowly(keep_me_expr) + self.assertAllEqual(keep_me_node.parent_index, [0, 1]) + self.assertAllEqual(keep_me_node.values, [False, True]) + self.assertFalse(keep_me_node.is_repeated) + + def test_promote_and_calculate_substructure_then_leaf(self): + """Tests promoting of substructure and then a leaf.""" + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_four_layer_prensor() + ) + new_root, new_nested_child_path = promote.promote_anonymous( + expr, path.Path(["event", "doc", "nested_child"]) + ) + new_root, new_bar_path = promote.promote_anonymous( + new_root, new_nested_child_path.get_child("bar") + ) + + # the promoted nested_child's parent index is changed. + new_nested_child = new_root.get_descendant_or_error(new_nested_child_path) + nested_child_node = expression_test_util.calculate_value_slowly( + new_nested_child + ) + self.assertAllEqual(nested_child_node.parent_index, [0, 1, 1, 1]) + self.assertTrue(nested_child_node.is_repeated) + + # promoted bar's parent index is changed. + new_bar = new_root.get_descendant_or_error(new_bar_path) + bar_node = expression_test_util.calculate_value_slowly(new_bar) + self.assertAllEqual(bar_node.parent_index, [0, 1, 1, 1]) + self.assertAllEqual(bar_node.values, [b"a", b"b", b"c", b"d"]) + self.assertTrue(bar_node.is_repeated) + + # bar's parent index should be unchanged. + nested_child_bar = new_root.get_descendant_or_error( + new_nested_child_path.get_child("bar") + ) + nested_child_bar_node = expression_test_util.calculate_value_slowly( + nested_child_bar + ) + self.assertAllEqual(nested_child_bar_node.parent_index, [0, 1, 1, 2]) + self.assertAllEqual(nested_child_bar_node.values, [b"a", b"b", b"c", b"d"]) + self.assertTrue(nested_child_bar_node.is_repeated) + + # keep_me's parent index should be unchanged. 
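For readers decoding these parent_index assertions: a parent_index vector is exactly the value_rowids of one ragged dimension. A small self-contained illustration (not from this test):

import tensorflow as tf

# [0, 1, 1] means row 0 of the parent has one child and row 1 has two.
rt = tf.RaggedTensor.from_value_rowids(
    values=tf.constant([b"a", b"b", b"c"]),
    value_rowids=[0, 1, 1],
)
# rt is [[b"a"], [b"b", b"c"]]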
+ nested_child_keep_me = new_root.get_descendant_or_error( + new_nested_child_path.get_child("keep_me") + ) + nested_child_keep_me_node = expression_test_util.calculate_value_slowly( + nested_child_keep_me + ) + self.assertAllEqual(nested_child_keep_me_node.parent_index, [0, 1]) + self.assertAllEqual(nested_child_keep_me_node.values, [False, True]) + self.assertFalse(nested_child_keep_me_node.is_repeated) + + def test_promote_and_calculate_leaf_then_substructure(self): + """Tests promoting of leaf and then a substructure.""" + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_four_layer_prensor() + ) + new_root, new_bar_path = promote.promote_anonymous( + expr, path.Path(["event", "doc", "nested_child", "bar"]) + ) + new_root, new_path = promote.promote_anonymous( + new_root, path.Path(["event", "doc"]) + ) + + new_doc = new_root.get_descendant_or_error(new_path) + new_bar = new_root.get_descendant_or_error( + new_path.concat(new_bar_path.suffix(2)) + ) + bar_expr = new_root.get_descendant_or_error( + new_path.concat(path.Path(["nested_child", "bar"])) + ) + keep_me_expr = new_root.get_descendant_or_error( + new_path.concat(path.Path(["nested_child", "keep_me"])) + ) + + new_doc_node = expression_test_util.calculate_value_slowly(new_doc) + self.assertAllEqual(new_doc_node.parent_index, [0, 1, 1]) + self.assertTrue(new_doc_node.is_repeated) + + # new_bar's parent index is changed (from the first promote). + # The second promote should not change new_bar's parent index. + new_bar_node = expression_test_util.calculate_value_slowly(new_bar) + self.assertAllEqual(new_bar_node.parent_index, [0, 1, 1, 1]) + self.assertAllEqual(new_bar_node.values, [b"a", b"b", b"c", b"d"]) + self.assertTrue(new_bar_node.is_repeated) + + # bar's parent index should be unchanged. + bar_node = expression_test_util.calculate_value_slowly(bar_expr) + self.assertAllEqual(bar_node.parent_index, [0, 1, 1, 2]) + self.assertAllEqual(bar_node.values, [b"a", b"b", b"c", b"d"]) + self.assertTrue(bar_node.is_repeated) + + # keep_me's parent index should be unchanged. + keep_me_node = expression_test_util.calculate_value_slowly(keep_me_expr) + self.assertAllEqual(keep_me_node.parent_index, [0, 1]) + self.assertAllEqual(keep_me_node.values, [False, True]) + self.assertFalse(keep_me_node.is_repeated) + + def test_promote_substructure_with_schema(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_deep_prensor() + ).apply_schema(prensor_test_util.create_deep_prensor_schema()) + + original_schema = expr.get_descendant_or_error( + path.Path(["event", "doc"]) + ).schema_feature + + new_root, new_field_path = promote.promote_anonymous( + expr, path.Path(["event", "doc"]) + ) + new_field = new_root.get_descendant_or_error(new_field_path) + new_schema_feature = new_field.schema_feature + self.assertIsNotNone(new_schema_feature) + + # The struct_domain of this feature should not be changed. 
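For context on the assertProtoEquals check that follows: struct_domain nests the schema of a structured feature's children, which is why it should survive promotion verbatim. A tiny hand-built example of its shape (illustrative only; the field names echo this test's fixture):

from tensorflow_metadata.proto.v0 import schema_pb2

feature = schema_pb2.Feature(name="doc")
feature.struct_domain.feature.add(name="bar")
feature.struct_domain.feature.add(name="keep_me")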
+ self.assertProtoEquals( + new_schema_feature.struct_domain, original_schema.struct_domain + ) + + bar_schema = new_root.get_descendant_or_error( + new_field_path.concat(path.Path(["bar"])) + ).schema_feature + self.assertIsNotNone(bar_schema) + self.assertEqual(bar_schema.string_domain.value[0], "a") + + keep_me_schema = new_root.get_descendant_or_error( + new_field_path.concat(path.Path(["keep_me"])) + ).schema_feature + self.assertIsNotNone(keep_me_schema) + self.assertEqual(keep_me_schema.presence.min_count, 1) def test_promote_lenient_format(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor_with_lenient_field_names(), - validate_step_format=False, - ) - new_root, _ = promote.promote_anonymous(expr, path.Path(["doc"])) - self.assertFalse(new_root.validate_step_format) - self.assertLen(new_root.get_known_descendants(), 1) + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor_with_lenient_field_names(), + validate_step_format=False, + ) + new_root, _ = promote.promote_anonymous(expr, path.Path(["doc"])) + self.assertFalse(new_root.validate_step_format) + self.assertLen(new_root.get_known_descendants(), 1) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_impl/proto.py b/struct2tensor/expression_impl/proto.py index 28faff3..5a4369e 100644 --- a/struct2tensor/expression_impl/proto.py +++ b/struct2tensor/expression_impl/proto.py @@ -20,15 +20,15 @@ import abc from typing import ( - Callable, - FrozenSet, - Mapping, - Optional, - Sequence, - Set, - Tuple, - Union, - cast, + Callable, + FrozenSet, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Union, + cast, ) import tensorflow as tf @@ -49,59 +49,62 @@ def is_proto_expression(expr: expression.Expression) -> bool: - """Returns true if an expression is a ProtoExpression.""" - return isinstance( - expr, (_ProtoRootExpression, _ProtoChildExpression, _ProtoLeafExpression)) + """Returns true if an expression is a ProtoExpression.""" + return isinstance( + expr, (_ProtoRootExpression, _ProtoChildExpression, _ProtoLeafExpression) + ) def create_expression_from_file_descriptor_set( tensor_of_protos: tf.Tensor, proto_name: ProtoFullName, file_descriptor_set: descriptor_pb2.FileDescriptorSet, - message_format: str = "binary") -> expression.Expression: - """Create an expression from a 1D tensor of serialized protos. - - Args: - tensor_of_protos: 1D tensor of serialized protos. - proto_name: fully qualified name (e.g. "some.package.SomeProto") of the - proto in `tensor_of_protos`. - file_descriptor_set: The FileDescriptorSet proto containing `proto_name`'s - and all its dependencies' FileDescriptorProto. Note that if file1 imports - file2, then file2's FileDescriptorProto must precede file1's in - file_descriptor_set.file. - message_format: Indicates the format of the protocol buffer: is one of - 'text' or 'binary'. - - Returns: - An expression. - """ - pool = DescriptorPool() - for f in file_descriptor_set.file: - # This method raises if f's dependencies have not been added. - pool.Add(f) - - # This method raises if proto not found. - desc = pool.FindMessageTypeByName(proto_name) - - return create_expression_from_proto(tensor_of_protos, desc, message_format) + message_format: str = "binary", +) -> expression.Expression: + """Create an expression from a 1D tensor of serialized protos. + + Args: + tensor_of_protos: 1D tensor of serialized protos. + proto_name: fully qualified name (e.g. 
"some.package.SomeProto") of the + proto in `tensor_of_protos`. + file_descriptor_set: The FileDescriptorSet proto containing `proto_name`'s + and all its dependencies' FileDescriptorProto. Note that if file1 imports + file2, then file2's FileDescriptorProto must precede file1's in + file_descriptor_set.file. + message_format: Indicates the format of the protocol buffer: is one of + 'text' or 'binary'. + + Returns: + An expression. + """ + pool = DescriptorPool() + for f in file_descriptor_set.file: + # This method raises if f's dependencies have not been added. + pool.Add(f) + + # This method raises if proto not found. + desc = pool.FindMessageTypeByName(proto_name) + + return create_expression_from_proto(tensor_of_protos, desc, message_format) def create_expression_from_proto( tensor_of_protos: tf.Tensor, desc: descriptor.Descriptor, - message_format: str = "binary") -> expression.Expression: - """Create an expression from a 1D tensor of serialized protos. + message_format: str = "binary", +) -> expression.Expression: + """Create an expression from a 1D tensor of serialized protos. - Args: - tensor_of_protos: 1D tensor of serialized protos. - desc: a descriptor of protos in tensor of protos. - message_format: Indicates the format of the protocol buffer: is one of - 'text' or 'binary'. + Args: + tensor_of_protos: 1D tensor of serialized protos. + desc: a descriptor of protos in tensor of protos. + message_format: Indicates the format of the protocol buffer: is one of + 'text' or 'binary'. - Returns: - An expression. - """ - return _ProtoRootExpression(desc, tensor_of_protos, message_format) + Returns: + An expression. + """ + return _ProtoRootExpression(desc, tensor_of_protos, message_format) # The function signature expected by `created_transformed_field`. @@ -118,492 +121,544 @@ def create_expression_from_proto( def create_transformed_field( - expr: expression.Expression, source_path: path.CoercableToPath, - dest_field: StrStep, transform_fn: TransformFn) -> expression.Expression: - """Create an expression that transforms serialized proto tensors. - - The transform_fn argument should take the form: - - def transform_fn(parent_indices, values): - ... - return (transformed_parent_indices, transformed_values) - - Given: - - parent_indices: an int64 vector of non-decreasing parent message indices. - - values: a string vector of serialized protos having the same shape as - `parent_indices`. - `transform_fn` must return new parent indices and serialized values encoding - the same proto message as the passed in `values`. These two vectors must - have the same size, but it need not be the same as the input arguments. - - Note: - If CalculateOptions.use_string_view (set at calculate time, thus this - Expression cannot know beforehand) is True, `values` passed to - `transform_fn` are string views pointing all the way back to the original - input tensor (of serialized root protos). And `transform_fn` must maintain - such views and avoid creating new values that are either not string views - into the root protos or self-owned strings. This is because downstream - decoding ops will still produce string views referring into its input - (which are string views into the root proto) and they will only hold a - reference to the original, root proto tensor, keeping it alive. So the input - tensor may get destroyed after the decoding op. - - In short, you can do element-wise transforms to `values`, but can't mutate - the contents of elements in `values` or create new elements. 
- - To lift this restriction, a decoding op must be told to hold a reference - of the input tensors of all its upstream decoding ops. - - - Args: - expr: a source expression containing `source_path`. - source_path: the path to the field to reverse. - dest_field: the name of the newly created field. This field will be a - sibling of the field identified by `source_path`. - transform_fn: a callable that accepts parent_indices and serialized proto - values and returns a posibly modified parent_indices and values. Note that - when CalcuateOptions.use_string_view is set, transform_fn should not have - any stateful side effecting uses of serialized proto inputs. Doing so - could cause segfaults as the backing string tensor lifetime is not - guaranteed when the side effecting operations are run. - - Returns: - An expression. - - Raises: - ValueError: if the source path is not a proto message field. - """ - source_path = path.create_path(source_path) - source_expr = expr.get_descendant_or_error(source_path) - if not isinstance(source_expr, _ProtoChildExpression): - raise ValueError( - f"Expected _ProtoChildExpression for field {str(source_path)}, but found {source_expr}.") - - if isinstance(source_expr, _TransformProtoChildExpression): - # In order to be able to propagate fields needed for parsing, the source - # expression of _TransformProtoChildExpression must always be the original - # _ProtoChildExpression before any transformation. This means that two - # sequentially applied _TransformProtoChildExpression would have the same - # source and would apply the transformation to the source directly, instead - # of one transform operating on the output of the other. - # To work around this, the user supplied transform function is wrapped to - # first call the source's transform function. - # The downside of this approach is that the initial transform may be - # applied redundantly if there are other expressions derived directly - # from it. - def final_transform(parent_indices: tf.Tensor, - values: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: - parent_indices, values = source_expr.transform_fn(parent_indices, values) - return transform_fn(parent_indices, values) - else: - final_transform = transform_fn - - transformed_expr = _TransformProtoChildExpression( - parent=source_expr._parent, # pylint: disable=protected-access - desc=source_expr._desc, # pylint: disable=protected-access - is_repeated=source_expr.is_repeated, - name_as_field=source_expr.name_as_field, - transform_fn=final_transform, - backing_str_tensor=source_expr._backing_str_tensor) # pylint: disable=protected-access - dest_path = source_path.get_parent().get_child(dest_field) - return expression_add.add_paths(expr, {dest_path: transformed_expr}) + expr: expression.Expression, + source_path: path.CoercableToPath, + dest_field: StrStep, + transform_fn: TransformFn, +) -> expression.Expression: + """Create an expression that transforms serialized proto tensors. + + The transform_fn argument should take the form: + + def transform_fn(parent_indices, values): + ... + return (transformed_parent_indices, transformed_values) + + Given: + - parent_indices: an int64 vector of non-decreasing parent message indices. + - values: a string vector of serialized protos having the same shape as + `parent_indices`. + `transform_fn` must return new parent indices and serialized values encoding + the same proto message as the passed in `values`. These two vectors must + have the same size, but it need not be the same as the input arguments. 
+ + Note: + If CalculateOptions.use_string_view (set at calculate time, thus this + Expression cannot know beforehand) is True, `values` passed to + `transform_fn` are string views pointing all the way back to the original + input tensor (of serialized root protos). And `transform_fn` must maintain + such views and avoid creating new values that are either not string views + into the root protos or self-owned strings. This is because downstream + decoding ops will still produce string views referring into their input + (which are string views into the root proto) and they will only hold a + reference to the original, root proto tensor, keeping it alive. So the input + tensor may get destroyed after the decoding op. + + In short, you can do element-wise transforms to `values`, but can't mutate + the contents of elements in `values` or create new elements. + + To lift this restriction, a decoding op must be told to hold a reference + to the input tensors of all its upstream decoding ops. + + + Args: + expr: a source expression containing `source_path`. + source_path: the path to the field to transform. + dest_field: the name of the newly created field. This field will be a + sibling of the field identified by `source_path`. + transform_fn: a callable that accepts parent_indices and serialized proto + values and returns a possibly modified parent_indices and values. Note that + when CalculateOptions.use_string_view is set, transform_fn should not have + any stateful side-effecting uses of serialized proto inputs. Doing so + could cause segfaults as the backing string tensor lifetime is not + guaranteed when the side-effecting operations are run. + + Returns: + An expression. + + Raises: + ValueError: if the source path is not a proto message field. + """ + source_path = path.create_path(source_path) + source_expr = expr.get_descendant_or_error(source_path) + if not isinstance(source_expr, _ProtoChildExpression): + raise ValueError( + f"Expected _ProtoChildExpression for field {str(source_path)}, but found {source_expr}." + ) + + if isinstance(source_expr, _TransformProtoChildExpression): + # In order to be able to propagate fields needed for parsing, the source + # expression of _TransformProtoChildExpression must always be the original + # _ProtoChildExpression before any transformation. This means that two + # sequentially applied _TransformProtoChildExpression would have the same + # source and would apply the transformation to the source directly, instead + # of one transform operating on the output of the other. + # To work around this, the user-supplied transform function is wrapped to + # first call the source's transform function. + # The downside of this approach is that the initial transform may be + # applied redundantly if there are other expressions derived directly + # from it.
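As a usage illustration of the contract documented above, a sketch of a user transform_fn that keeps only the first serialized submessage under each parent (keep_first_per_parent is a hypothetical name, not part of this module):

import tensorflow as tf


def keep_first_per_parent(parent_indices, values):
    # parent_indices is non-decreasing, so an element is the first under its
    # parent exactly when its parent differs from the previous element's.
    # Selecting a subset is an element-wise transform, so string views are
    # preserved when CalculateOptions.use_string_view is set.
    is_first = tf.concat(
        [[True], tf.not_equal(parent_indices[1:], parent_indices[:-1])], axis=0
    )
    return (
        tf.boolean_mask(parent_indices, is_first),
        tf.boolean_mask(values, is_first),
    )

# Hypothetical call site:
# new_expr = create_transformed_field(
#     expr, "event.doc", "first_doc", keep_first_per_parent)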
+ def final_transform( + parent_indices: tf.Tensor, values: tf.Tensor + ) -> Tuple[tf.Tensor, tf.Tensor]: + parent_indices, values = source_expr.transform_fn(parent_indices, values) + return transform_fn(parent_indices, values) + else: + final_transform = transform_fn + + transformed_expr = _TransformProtoChildExpression( + parent=source_expr._parent, # pylint: disable=protected-access + desc=source_expr._desc, # pylint: disable=protected-access + is_repeated=source_expr.is_repeated, + name_as_field=source_expr.name_as_field, + transform_fn=final_transform, + backing_str_tensor=source_expr._backing_str_tensor, + ) # pylint: disable=protected-access + dest_path = source_path.get_parent().get_child(dest_field) + return expression_add.add_paths(expr, {dest_path: transformed_expr}) class _ProtoRootNodeTensor(prensor.RootNodeTensor): - """The value of the root node. + """The value of the root node. - This not only contains the normal size information, but also information - needed by its children. + This not only contains the normal size information, but also information + needed by its children. - In particular: - 1. Any needed regular fields are included. - 2. Any needed extended fields are included. - 3. Any needed map fields are included. - 4. if this is an Any proto, any needed casted fields are included. + In particular: + 1. Any needed regular fields are included. + 2. Any needed extended fields are included. + 3. Any needed map fields are included. + 4. if this is an Any proto, any needed casted fields are included. - """ + """ - def __init__(self, size: tf.Tensor, - fields: Mapping[StrStep, struct2tensor_ops._ParsedField]): - super().__init__(size) - self.fields = fields + def __init__( + self, size: tf.Tensor, fields: Mapping[StrStep, struct2tensor_ops._ParsedField] + ): + super().__init__(size) + self.fields = fields class _ProtoChildNodeTensor(prensor.ChildNodeTensor): - """The value of a child node. + """The value of a child node. - This not only contains the normal parent_index information, but also - information needed by its children. + This not only contains the normal parent_index information, but also + information needed by its children. - In particular: - 1. Any needed regular fields are included. - 2. Any needed extended fields are included. - 3. Any needed map fields are included. - 4. if this is an Any proto, any needed casted fields are included. - """ + In particular: + 1. Any needed regular fields are included. + 2. Any needed extended fields are included. + 3. Any needed map fields are included. + 4. if this is an Any proto, any needed casted fields are included. 
+ """ - def __init__(self, parent_index: tf.Tensor, is_repeated: bool, - fields: Mapping[StrStep, struct2tensor_ops._ParsedField]): - super().__init__(parent_index, is_repeated) - self.fields = fields + def __init__( + self, + parent_index: tf.Tensor, + is_repeated: bool, + fields: Mapping[StrStep, struct2tensor_ops._ParsedField], + ): + super().__init__(parent_index, is_repeated) + self.fields = fields _ParentProtoNodeTensor = Union[_ProtoRootNodeTensor, _ProtoChildNodeTensor] class _AbstractProtoChildExpression(expression.Expression): - """A child or leaf proto expression.""" - - def __init__(self, parent: "_ParentProtoExpression", name_as_field: StrStep, - is_repeated: bool, my_type: Optional[tf.DType], - backing_str_tensor: Optional[tf.Tensor]): - super().__init__(is_repeated, my_type) - self._parent = parent - self._name_as_field = name_as_field - - self._backing_str_tensor = backing_str_tensor - - @property - def name_as_field(self) -> StrStep: - return self._name_as_field - - def get_needed_fields(self, expr: expression.Expression) -> Sequence[StrStep]: - return [self._name_as_field] - - def get_path(self) -> path.Path: - """Returns the path to the root of the proto.""" - return self._parent.get_path().get_child(self.name_as_field) - - def get_proto_source(self) -> Tuple[tf.Tensor, descriptor.Descriptor]: - """Returns the proto root.""" - return self._parent.get_proto_source() - - def get_source_expressions(self) -> Sequence[expression.Expression]: - # In order to parse this proto, you need to parse its parent. - return [self._parent] - - def calculate( - self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: - [parent_value] = sources - if isinstance(parent_value, (_ProtoRootNodeTensor, _ProtoChildNodeTensor)): - parsed_field = parent_value.fields.get(self.name_as_field) - if parsed_field is None: - raise ValueError(f"Cannot find {str(self)} in {str(parent_value)}") - return self.calculate_from_parsed_field(parsed_field, destinations, - options) - raise ValueError("Not a _ParentProtoNodeTensor: " + str(type(parent_value))) - - @abc.abstractmethod - def calculate_from_parsed_field( - self, parsed_field: struct2tensor_ops._ParsedField, # pylint: disable=protected-access - destinations: Sequence[expression.Expression], - options: calculate_options.Options) -> prensor.NodeTensor: - """Calculate the NodeTensor given the parsed fields requested from a parent. - - Args: - parsed_field: the parsed field from name_as_field. - destinations: the destination of the expression. - options: calculate options. - - Returns: - A node tensor for this node. 
- """ - raise NotImplementedError() - - def calculation_is_identity(self) -> bool: - return False + """A child or leaf proto expression.""" + + def __init__( + self, + parent: "_ParentProtoExpression", + name_as_field: StrStep, + is_repeated: bool, + my_type: Optional[tf.DType], + backing_str_tensor: Optional[tf.Tensor], + ): + super().__init__(is_repeated, my_type) + self._parent = parent + self._name_as_field = name_as_field + + self._backing_str_tensor = backing_str_tensor + + @property + def name_as_field(self) -> StrStep: + return self._name_as_field + + def get_needed_fields(self, expr: expression.Expression) -> Sequence[StrStep]: + return [self._name_as_field] + + def get_path(self) -> path.Path: + """Returns the path to the root of the proto.""" + return self._parent.get_path().get_child(self.name_as_field) + + def get_proto_source(self) -> Tuple[tf.Tensor, descriptor.Descriptor]: + """Returns the proto root.""" + return self._parent.get_proto_source() + + def get_source_expressions(self) -> Sequence[expression.Expression]: + # In order to parse this proto, you need to parse its parent. + return [self._parent] + + def calculate( + self, + sources: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> prensor.NodeTensor: + [parent_value] = sources + if isinstance(parent_value, (_ProtoRootNodeTensor, _ProtoChildNodeTensor)): + parsed_field = parent_value.fields.get(self.name_as_field) + if parsed_field is None: + raise ValueError(f"Cannot find {str(self)} in {str(parent_value)}") + return self.calculate_from_parsed_field(parsed_field, destinations, options) + raise ValueError("Not a _ParentProtoNodeTensor: " + str(type(parent_value))) + + @abc.abstractmethod + def calculate_from_parsed_field( + self, + parsed_field: struct2tensor_ops._ParsedField, # pylint: disable=protected-access + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + ) -> prensor.NodeTensor: + """Calculate the NodeTensor given the parsed fields requested from a parent. + + Args: + parsed_field: the parsed field from name_as_field. + destinations: the destination of the expression. + options: calculate options. + + Returns: + A node tensor for this node. + """ + raise NotImplementedError() + + def calculation_is_identity(self) -> bool: + return False class _ProtoLeafExpression(_AbstractProtoChildExpression): - """Represents parsing a leaf field.""" + """Represents parsing a leaf field.""" + + def __init__( + self, + parent: "_ParentProtoExpression", + desc: descriptor.FieldDescriptor, + name_as_field: path.Step, + ): + """Initialize a proto leaf expression. + + Args: + parent: the parent of the expression. + desc: the field descriptor of the expression name_as_field. + name_as_field: the name of the field. + """ + super().__init__( + parent, + name_as_field, + desc.label == descriptor.FieldDescriptor.LABEL_REPEATED, + struct2tensor_ops._get_dtype_from_cpp_type(desc.cpp_type), + None, + ) # pylint: disable=protected-access + # TODO(martinz): make _get_dtype_from_cpp_type public. 
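For reference beside the TODO above, this is the kind of cpp_type-to-dtype table that _get_dtype_from_cpp_type encapsulates; the authoritative mapping lives in struct2tensor_ops, so this sketch may differ in detail (enum and message types are handled separately):

import tensorflow as tf
from google.protobuf import descriptor

_CPP_TYPE_TO_DTYPE = {
    descriptor.FieldDescriptor.CPPTYPE_INT32: tf.int32,
    descriptor.FieldDescriptor.CPPTYPE_INT64: tf.int64,
    descriptor.FieldDescriptor.CPPTYPE_UINT32: tf.uint32,
    descriptor.FieldDescriptor.CPPTYPE_UINT64: tf.uint64,
    descriptor.FieldDescriptor.CPPTYPE_FLOAT: tf.float32,
    descriptor.FieldDescriptor.CPPTYPE_DOUBLE: tf.float64,
    descriptor.FieldDescriptor.CPPTYPE_BOOL: tf.bool,
    descriptor.FieldDescriptor.CPPTYPE_STRING: tf.string,
}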
+ self._field_descriptor = desc + + def calculate_from_parsed_field( + self, + parsed_field: struct2tensor_ops._ParsedField, # pylint: disable=protected-access + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + ) -> prensor.NodeTensor: + return prensor.LeafNodeTensor( + parsed_field.index, parsed_field.value, self.is_repeated + ) + + def calculation_equal(self, expr: expression.Expression) -> bool: + # pylint: disable=protected-access + return ( + isinstance(expr, _ProtoLeafExpression) + and self._field_descriptor == expr._field_descriptor + and self.name_as_field == expr.name_as_field + ) + + def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]: + return None + + def known_field_names(self) -> FrozenSet[path.Step]: + return frozenset() + + def __str__(self) -> str: + return f"_ProtoLeafExpression: {self.name_as_field} from {self._parent}" - def __init__(self, parent: "_ParentProtoExpression", - desc: descriptor.FieldDescriptor, name_as_field: path.Step): - """Initialize a proto leaf expression. - Args: - parent: the parent of the expression. - desc: the field descriptor of the expression name_as_field. - name_as_field: the name of the field. +class _ProtoChildExpression(_AbstractProtoChildExpression): + """An expression representing a proto submessage. + + Supports: + A standard submessage. + An extension submessage. + A protobuf.Any submessage. + A proto map submessage. + Also supports having fields of the above types. """ - super().__init__(parent, name_as_field, - desc.label == descriptor.FieldDescriptor.LABEL_REPEATED, - struct2tensor_ops._get_dtype_from_cpp_type(desc.cpp_type), - None) # pylint: disable=protected-access - # TODO(martinz): make _get_dtype_from_cpp_type public. - self._field_descriptor = desc - - def calculate_from_parsed_field( - self, parsed_field: struct2tensor_ops._ParsedField, # pylint: disable=protected-access - destinations: Sequence[expression.Expression], - options: calculate_options.Options) -> prensor.NodeTensor: - return prensor.LeafNodeTensor(parsed_field.index, parsed_field.value, - self.is_repeated) - def calculation_equal(self, expr: expression.Expression) -> bool: - # pylint: disable=protected-access - return (isinstance(expr, _ProtoLeafExpression) and - self._field_descriptor == expr._field_descriptor and - self.name_as_field == expr.name_as_field) + def __init__( + self, + parent: "_ParentProtoExpression", + desc: descriptor.Descriptor, + is_repeated: bool, + name_as_field: StrStep, + backing_str_tensor: Optional[tf.Tensor], + ): + """Initialize a _ProtoChildExpression. + + This does not take a field descriptor so it can represent syntactic sugar + fields such as Any and Maps. + + Args: + parent: the parent. + desc: the message descriptor of the submessage represented by this + expression. + is_repeated: whether the field is repeated. + name_as_field: the name of the field. + backing_str_tensor: a string tensor representing the root serialized + proto. 
This is passed to keep string_views of the tensor valid for + all children of the root expression + """ + super().__init__(parent, name_as_field, is_repeated, None, backing_str_tensor) + self._desc = desc + + def calculate_from_parsed_field( + self, + parsed_field: struct2tensor_ops._ParsedField, # pylint:disable=protected-access + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + ) -> prensor.NodeTensor: + needed_fields = _get_needed_fields(destinations) + backing_str_tensor = None + if options.use_string_view: + backing_str_tensor = self._backing_str_tensor + fields = parse_message_level_ex.parse_message_level_ex( + parsed_field.value, + self._desc, + needed_fields, + backing_str_tensor=backing_str_tensor, + honor_proto3_optional_semantics=options.experimental_honor_proto3_optional_semantics, + ) + return _ProtoChildNodeTensor(parsed_field.index, self.is_repeated, fields) + + def calculation_equal(self, expr: expression.Expression) -> bool: + # Ensure that we're dealing with the _ProtoChildExpression and not any + # of its subclasses. + if type(expr) != _ProtoChildExpression: # noqa: E721 + return False + expr = cast(_ProtoChildExpression, expr) # Keep pytype happy. + return ( + self._desc == expr._desc # pylint: disable=protected-access + and self.name_as_field == expr.name_as_field + ) + + def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]: + return _get_child(self, self._desc, field_name, self._backing_str_tensor) + + def known_field_names(self) -> FrozenSet[path.Step]: + return _known_field_names_from_descriptor(self._desc) + + def __str__(self) -> str: + return f"_ProtoChildExpression: name_as_field: {str(self.name_as_field)} desc: {str(self._desc.full_name)} from {self._parent}" - def _get_child_impl(self, - field_name: path.Step) -> Optional[expression.Expression]: - return None - def known_field_names(self) -> FrozenSet[path.Step]: - return frozenset() - - def __str__(self) -> str: - return f"_ProtoLeafExpression: {self.name_as_field} from {self._parent}" - - -class _ProtoChildExpression(_AbstractProtoChildExpression): - """An expression representing a proto submessage. - - Supports: - A standard submessage. - An extension submessage. - A protobuf.Any submessage. - A proto map submessage. - Also supports having fields of the above types. 
- """ +class _TransformProtoChildExpression(_ProtoChildExpression): + """Transforms the parent indices and values prior to parsing.""" + + def __init__( + self, + parent: "_ParentProtoExpression", + desc: descriptor.Descriptor, + is_repeated: bool, + name_as_field: StrStep, + transform_fn: TransformFn, + backing_str_tensor: Optional[tf.Tensor], + ): + super().__init__(parent, desc, is_repeated, name_as_field, backing_str_tensor) + self._transform_fn = transform_fn + + @property + def transform_fn(self): + return self._transform_fn + + def calculate_from_parsed_field( + self, + parsed_field: struct2tensor_ops._ParsedField, # pylint:disable=protected-access + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + ) -> prensor.NodeTensor: + needed_fields = _get_needed_fields(destinations) + transformed_parent_indices, transformed_values = self._transform_fn( + parsed_field.index, parsed_field.value + ) + backing_str_tensor = None + if options.use_string_view: + backing_str_tensor = self._backing_str_tensor + fields = parse_message_level_ex.parse_message_level_ex( + transformed_values, + self._desc, + needed_fields, + backing_str_tensor=backing_str_tensor, + honor_proto3_optional_semantics=options.experimental_honor_proto3_optional_semantics, + ) + return _ProtoChildNodeTensor( + transformed_parent_indices, self.is_repeated, fields + ) + + def calculation_equal(self, expr: expression.Expression) -> bool: + return ( + isinstance(expr, _TransformProtoChildExpression) + and self._desc == expr._desc # pylint: disable=protected-access + and self.name_as_field == expr.name_as_field + and self.transform_fn is expr.transform_fn + ) + + def __str__(self) -> str: + return f"_TransformProtoChildExpression: name_as_field: {str(self.name_as_field)} desc: {str(self._desc.full_name)} from {self._parent}" - def __init__(self, parent: "_ParentProtoExpression", - desc: descriptor.Descriptor, is_repeated: bool, - name_as_field: StrStep, backing_str_tensor: Optional[tf.Tensor]): - """Initialize a _ProtoChildExpression. - This does not take a field descriptor so it can represent syntactic sugar - fields such as Any and Maps. +class _ProtoRootExpression(expression.Expression): + """The expression representing the parse of the root of a proto. - Args: - parent: the parent. - desc: the message descriptor of the submessage represented by this - expression. - is_repeated: whether the field is repeated. - name_as_field: the name of the field. - backing_str_tensor: a string tensor representing the root serialized - proto. This is passed to keep string_views of the tensor valid for - all children of the root expression + This class returns a _ProtoRootNodeTensor, that parses out fields for + _ProtoChildExpression and _ProtoLeafExpression to consume. 
""" - super().__init__(parent, name_as_field, is_repeated, None, - backing_str_tensor) - self._desc = desc - - def calculate_from_parsed_field( - self, parsed_field: struct2tensor_ops._ParsedField, # pylint:disable=protected-access - destinations: Sequence[expression.Expression], - options: calculate_options.Options) -> prensor.NodeTensor: - needed_fields = _get_needed_fields(destinations) - backing_str_tensor = None - if options.use_string_view: - backing_str_tensor = self._backing_str_tensor - fields = parse_message_level_ex.parse_message_level_ex( - parsed_field.value, - self._desc, - needed_fields, - backing_str_tensor=backing_str_tensor, - honor_proto3_optional_semantics=options - .experimental_honor_proto3_optional_semantics) - return _ProtoChildNodeTensor(parsed_field.index, self.is_repeated, fields) - - def calculation_equal(self, expr: expression.Expression) -> bool: - # Ensure that we're dealing with the _ProtoChildExpression and not any - # of its subclasses. - if type(expr) != _ProtoChildExpression: # noqa: E721 - return False - expr = cast(_ProtoChildExpression, expr) # Keep pytype happy. - return (self._desc == expr._desc and # pylint: disable=protected-access - self.name_as_field == expr.name_as_field) - - def _get_child_impl(self, - field_name: path.Step) -> Optional[expression.Expression]: - return _get_child(self, self._desc, field_name, self._backing_str_tensor) - - def known_field_names(self) -> FrozenSet[path.Step]: - return _known_field_names_from_descriptor(self._desc) - - def __str__(self) -> str: - return f"_ProtoChildExpression: name_as_field: {str(self.name_as_field)} desc: {str(self._desc.full_name)} from {self._parent}" - -class _TransformProtoChildExpression(_ProtoChildExpression): - """Transforms the parent indices and values prior to parsing.""" - - def __init__(self, parent: "_ParentProtoExpression", - desc: descriptor.Descriptor, is_repeated: bool, - name_as_field: StrStep, transform_fn: TransformFn, - backing_str_tensor: Optional[tf.Tensor]): - super().__init__(parent, desc, is_repeated, name_as_field, - backing_str_tensor) - self._transform_fn = transform_fn - - @property - def transform_fn(self): - return self._transform_fn - - def calculate_from_parsed_field( - self, - parsed_field: struct2tensor_ops._ParsedField, # pylint:disable=protected-access - destinations: Sequence[expression.Expression], - options: calculate_options.Options) -> prensor.NodeTensor: - needed_fields = _get_needed_fields(destinations) - transformed_parent_indices, transformed_values = self._transform_fn( - parsed_field.index, parsed_field.value) - backing_str_tensor = None - if options.use_string_view: - backing_str_tensor = self._backing_str_tensor - fields = parse_message_level_ex.parse_message_level_ex( - transformed_values, - self._desc, - needed_fields, - backing_str_tensor=backing_str_tensor, - honor_proto3_optional_semantics=options - .experimental_honor_proto3_optional_semantics) - return _ProtoChildNodeTensor(transformed_parent_indices, self.is_repeated, - fields) - - def calculation_equal(self, expr: expression.Expression) -> bool: - return (isinstance(expr, _TransformProtoChildExpression) and - self._desc == expr._desc and # pylint: disable=protected-access - self.name_as_field == expr.name_as_field - and self.transform_fn is expr.transform_fn) - - def __str__(self) -> str: - return (f"_TransformProtoChildExpression: name_as_field: {str(self.name_as_field)} desc: {str(self._desc.full_name)} from {self._parent}" + def __init__( + self, + desc: descriptor.Descriptor, + 
tensor_of_protos: tf.Tensor, + message_format: str = "binary", + ): + """Initialize a proto expression. + + Args: + desc: the descriptor of the expression. + tensor_of_protos: a 1-D tensor to get the protos from. + message_format: Indicates the format of the protocol buffer: is one of + 'text' or 'binary'. + """ + super().__init__(True, None) + self._descriptor = desc + self._tensor_of_protos = tensor_of_protos + self._message_format = message_format + + def get_path(self) -> path.Path: + """Returns the path to the root of the proto.""" + return path.Path([]) + + def get_proto_source(self) -> Tuple[tf.Tensor, descriptor.Descriptor]: + """Returns the tensor of protos and the original descriptor.""" + return (self._tensor_of_protos, self._descriptor) + + def get_source_expressions(self) -> Sequence[expression.Expression]: + return [] + + def calculate( + self, + sources: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> _ProtoRootNodeTensor: + if sources: + raise ValueError("_ProtoRootExpression has no sources") + size = tf.size(self._tensor_of_protos, out_type=tf.int64) + needed_fields = _get_needed_fields(destinations) + backing_str_tensor = None + if options.use_string_view: + assert self._message_format == "binary", ( + "`options.use_string_view` is only compatible with 'binary' message " + "format. Please create the root expression with " + "message_format='binary'." ) + backing_str_tensor = self._tensor_of_protos + fields = parse_message_level_ex.parse_message_level_ex( + self._tensor_of_protos, + self._descriptor, + needed_fields, + message_format=self._message_format, + backing_str_tensor=backing_str_tensor, + honor_proto3_optional_semantics=options.experimental_honor_proto3_optional_semantics, + ) + return _ProtoRootNodeTensor(size, fields) + def calculation_is_identity(self) -> bool: + return False -class _ProtoRootExpression(expression.Expression): - """The expression representing the parse of the root of a proto. + def calculation_equal(self, expr: expression.Expression) -> bool: + # TODO(martinz): In theory, we could check for the equality of the + # tensor_of_protos and the descriptors. + return self is expr - This class returns a _ProtoRootNodeTensor, that parses out fields for - _ProtoChildExpression and _ProtoLeafExpression to consume. - """ + def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]: + return _get_child(self, self._descriptor, field_name, self._tensor_of_protos) - def __init__(self, - desc: descriptor.Descriptor, - tensor_of_protos: tf.Tensor, - message_format: str = "binary"): - """Initialize a proto expression. + def known_field_names(self) -> FrozenSet[path.Step]: + return _known_field_names_from_descriptor(self._descriptor) + + def __str__(self) -> str: + return f"_ProtoRootExpression: {str(self._descriptor.full_name)}" - Args: - desc: the descriptor of the expression. - tensor_of_protos: a 1-D tensor to get the protos from. - message_format: Indicates the format of the protocol buffer: is one of - 'text' or 'binary'. 
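For orientation, this is how such a root expression is typically constructed (mirroring the tests below); tf.train.Example is just a convenient message type here, and the empty proto is purely illustrative:

import tensorflow as tf

from struct2tensor.expression_impl import proto

example = tf.train.Example()  # An empty proto, for illustration only.
serialized = tf.constant([example.SerializeToString()])
expr = proto.create_expression_from_proto(
    serialized, tf.train.Example.DESCRIPTOR)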
- """ - super().__init__(True, None) - self._descriptor = desc - self._tensor_of_protos = tensor_of_protos - self._message_format = message_format - - def get_path(self) -> path.Path: - """Returns the path to the root of the proto.""" - return path.Path([]) - - def get_proto_source(self) -> Tuple[tf.Tensor, descriptor.Descriptor]: - """Returns the tensor of protos and the original descriptor.""" - return (self._tensor_of_protos, self._descriptor) - - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [] - - def calculate( - self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> _ProtoRootNodeTensor: - if sources: - raise ValueError("_ProtoRootExpression has no sources") - size = tf.size(self._tensor_of_protos, out_type=tf.int64) - needed_fields = _get_needed_fields(destinations) - backing_str_tensor = None - if options.use_string_view: - assert self._message_format == "binary", ( - "`options.use_string_view` is only compatible with 'binary' message " - "format. Please create the root expression with " - "message_format='binary'.") - backing_str_tensor = self._tensor_of_protos - fields = parse_message_level_ex.parse_message_level_ex( - self._tensor_of_protos, - self._descriptor, - needed_fields, - message_format=self._message_format, - backing_str_tensor=backing_str_tensor, - honor_proto3_optional_semantics=options - .experimental_honor_proto3_optional_semantics) - return _ProtoRootNodeTensor(size, fields) - - def calculation_is_identity(self) -> bool: - return False - - def calculation_equal(self, expr: expression.Expression) -> bool: - # TODO(martinz): In theory, we could check for the equality of the - # tensor_of_protos and the descriptors. 
- return self is expr - - def _get_child_impl(self, - field_name: path.Step) -> Optional[expression.Expression]: - return _get_child(self, self._descriptor, field_name, - self._tensor_of_protos) - - def known_field_names(self) -> FrozenSet[path.Step]: - return _known_field_names_from_descriptor(self._descriptor) - - def __str__(self) -> str: - return f"_ProtoRootExpression: {str(self._descriptor.full_name)}" - - -ProtoExpression = Union[_ProtoRootExpression, _ProtoChildExpression, # pylint: disable=invalid-name - _ProtoLeafExpression] + +ProtoExpression = Union[ + _ProtoRootExpression, + _ProtoChildExpression, # pylint: disable=invalid-name + _ProtoLeafExpression, +] _ParentProtoExpression = Union[_ProtoChildExpression, _ProtoRootExpression] def _known_field_names_from_descriptor( - desc: descriptor.Descriptor) -> FrozenSet[StrStep]: - return frozenset([field.name for field in desc.fields]) + desc: descriptor.Descriptor, +) -> FrozenSet[StrStep]: + return frozenset([field.name for field in desc.fields]) def _get_field_descriptor( - desc: descriptor.Descriptor, - field_name: ProtoFieldName) -> Optional[descriptor.FieldDescriptor]: - if path.is_extension(field_name): - try: - return desc.file.pool.FindExtensionByName( - path.get_raw_extension_name(field_name)) - except KeyError: - return None - return desc.fields_by_name.get(field_name) + desc: descriptor.Descriptor, field_name: ProtoFieldName +) -> Optional[descriptor.FieldDescriptor]: + if path.is_extension(field_name): + try: + return desc.file.pool.FindExtensionByName( + path.get_raw_extension_name(field_name) + ) + except KeyError: + return None + return desc.fields_by_name.get(field_name) def _get_any_child( - parent: Union[_ProtoChildExpression, - _ProtoRootExpression], desc: descriptor.Descriptor, - field_name: ProtoFieldName, backing_str_tensor: Optional[tf.Tensor] + parent: Union[_ProtoChildExpression, _ProtoRootExpression], + desc: descriptor.Descriptor, + field_name: ProtoFieldName, + backing_str_tensor: Optional[tf.Tensor], ) -> Optional[Union[_ProtoLeafExpression, _ProtoChildExpression]]: - """Gets the child of an any descriptor.""" - if path.is_extension(field_name): - full_name_child = parse_message_level_ex.get_full_name_from_any_step( - field_name) - if full_name_child is None: - return None - field_message = desc.file.pool.FindMessageTypeByName(full_name_child) - return _ProtoChildExpression(parent, field_message, False, field_name, - backing_str_tensor) - return _get_child_helper(parent, desc.fields_by_name.get(field_name), - field_name, backing_str_tensor) + """Gets the child of an any descriptor.""" + if path.is_extension(field_name): + full_name_child = parse_message_level_ex.get_full_name_from_any_step(field_name) + if full_name_child is None: + return None + field_message = desc.file.pool.FindMessageTypeByName(full_name_child) + return _ProtoChildExpression( + parent, field_message, False, field_name, backing_str_tensor + ) + return _get_child_helper( + parent, desc.fields_by_name.get(field_name), field_name, backing_str_tensor + ) def _is_map_field_desc(field_desc: descriptor.FieldDescriptor) -> bool: - return (field_desc.message_type and - field_desc.message_type.GetOptions().map_entry) + return field_desc.message_type and field_desc.message_type.GetOptions().map_entry def _get_map_child( @@ -612,59 +667,62 @@ def _get_map_child( field_name: ProtoFieldName, backing_str_tensor: Optional[tf.Tensor], ) -> Optional[Union[_ProtoLeafExpression, _ProtoChildExpression]]: - """Gets the child given a map field.""" - 
[map_field_name, _] = path.parse_map_indexing_step(field_name) - map_field_desc = desc.fields_by_name.get(map_field_name) - if map_field_desc is None: - return None - if not _is_map_field_desc(map_field_desc): - return None - map_message_desc = map_field_desc.message_type - if map_message_desc is None: - # Note: I don't know if this is reachable. Theoretically, _is_map_field_desc - # should have already returned false. - return None - value_field_desc = map_message_desc.fields_by_name.get("value") - if value_field_desc is None: - # Note: I don't know if this is reachable. Theoretically, _is_map_field_desc - # should have already returned false. - return None - # This relies on the fact that the value is an optional field. - return _get_child_helper(parent, value_field_desc, field_name, - backing_str_tensor) + """Gets the child given a map field.""" + [map_field_name, _] = path.parse_map_indexing_step(field_name) + map_field_desc = desc.fields_by_name.get(map_field_name) + if map_field_desc is None: + return None + if not _is_map_field_desc(map_field_desc): + return None + map_message_desc = map_field_desc.message_type + if map_message_desc is None: + # Note: I don't know if this is reachable. Theoretically, _is_map_field_desc + # should have already returned false. + return None + value_field_desc = map_message_desc.fields_by_name.get("value") + if value_field_desc is None: + # Note: I don't know if this is reachable. Theoretically, _is_map_field_desc + # should have already returned false. + return None + # This relies on the fact that the value is an optional field. + return _get_child_helper(parent, value_field_desc, field_name, backing_str_tensor) def _get_child_helper( parent: Union[_ProtoChildExpression, _ProtoRootExpression], field_descriptor: Optional[descriptor.FieldDescriptor], - field_name: ProtoFieldName, backing_str_tensor: Optional[tf.Tensor] + field_name: ProtoFieldName, + backing_str_tensor: Optional[tf.Tensor], ) -> Optional[Union[_ProtoChildExpression, _ProtoLeafExpression]]: - """Helper function for _get_child, _get_any_child, and _get_map_child. - - Note that the field_descriptor.field_name is not necessarily equal to - field_name, especially if this is called from _get_map_child. - - Args: - parent: the parent expression - field_descriptor: the field descriptor of the submessage represented by the - returned expression, if present. If None, this will just return None. - field_name: the field name of the _AbstractProtoChildExpression returned. - backing_str_tensor: a string tensor representing the root serialized proto. - This is passed to keep string_views of the tensor valid for all children - of the root expression - - Returns: - An _AbstractProtoChildExpression. - """ - if field_descriptor is None: - return None - field_message = field_descriptor.message_type - if field_message is None: - return _ProtoLeafExpression(parent, field_descriptor, field_name) - return _ProtoChildExpression( - parent, field_message, - field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED, - field_name, backing_str_tensor) + """Helper function for _get_child, _get_any_child, and _get_map_child. + + Note that the field_descriptor.field_name is not necessarily equal to + field_name, especially if this is called from _get_map_child. + + Args: + parent: the parent expression + field_descriptor: the field descriptor of the submessage represented by the + returned expression, if present. If None, this will just return None. 
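The map handling above keys off the indexing-step syntax; a hedged sketch of the path helpers it relies on (the field name and key are illustrative, taken from the map tests below):

from struct2tensor import path

step = "int32_string_map[222]"  # Selects the value under map key 222.
assert path.is_map_indexing_step(step)
map_field_name, key = path.parse_map_indexing_step(step)
# Expected: map_field_name == "int32_string_map", key == "222".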
+ field_name: the field name of the _AbstractProtoChildExpression returned. + backing_str_tensor: a string tensor representing the root serialized proto. + This is passed to keep string_views of the tensor valid for all children + of the root expression + + Returns: + An _AbstractProtoChildExpression. + """ + if field_descriptor is None: + return None + field_message = field_descriptor.message_type + if field_message is None: + return _ProtoLeafExpression(parent, field_descriptor, field_name) + return _ProtoChildExpression( + parent, + field_message, + field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED, + field_name, + backing_str_tensor, + ) def _get_child( @@ -673,40 +731,40 @@ def _get_child( field_name: path.Step, backing_str_tensor: Optional[tf.Tensor], ) -> Optional[Union[_ProtoChildExpression, _ProtoLeafExpression]]: - """Get a child expression. - - This will get one of the following: - A regular field. - An extension. - An Any filtered by value. - A map field. - - Args: - parent: The parent expression. - desc: The descriptor of the parent. - field_name: The name of the field. - backing_str_tensor: a string tensor representing the root serialized proto. - This is passed to keep string_views of the tensor valid for all children - of the root expression - - Returns: - The child expression, either a submessage or a leaf. - """ - if isinstance(field_name, path.AnonymousId): - return None - if parse_message_level_ex.is_any_descriptor(desc): - return _get_any_child(parent, desc, field_name, backing_str_tensor) - if path.is_map_indexing_step(field_name): - return _get_map_child(parent, desc, field_name, backing_str_tensor) - # Works for extensions and regular fields, but not any or map. - return _get_child_helper(parent, _get_field_descriptor(desc, field_name), - field_name, backing_str_tensor) - - -def _get_needed_fields( - destinations: Sequence[expression.Expression]) -> Set[StrStep]: - field_names = set() # type: Set[StrStep] - for destination in destinations: - if isinstance(destination, _AbstractProtoChildExpression): - field_names.add(destination.name_as_field) - return field_names + """Get a child expression. + + This will get one of the following: + A regular field. + An extension. + An Any filtered by value. + A map field. + + Args: + parent: The parent expression. + desc: The descriptor of the parent. + field_name: The name of the field. + backing_str_tensor: a string tensor representing the root serialized proto. + This is passed to keep string_views of the tensor valid for all children + of the root expression + + Returns: + The child expression, either a submessage or a leaf. + """ + if isinstance(field_name, path.AnonymousId): + return None + if parse_message_level_ex.is_any_descriptor(desc): + return _get_any_child(parent, desc, field_name, backing_str_tensor) + if path.is_map_indexing_step(field_name): + return _get_map_child(parent, desc, field_name, backing_str_tensor) + # Works for extensions and regular fields, but not any or map. 
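As the dispatch in _get_child shows, Any submessages are addressed by a parenthesized type-URL step and extensions by a parenthesized fully-qualified name; the paths below mirror the tests that follow:

from struct2tensor import path

# An Any field filtered to one packed type (see the Any tests below):
any_path = path.Path(
    ["my_any", "(type.googleapis.com/struct2tensor.test.AllSimple)"])
# An extension step is a parenthesized fully-qualified name:
assert path.is_extension("(struct2tensor.test.MyExternalExtension.ext)")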
+ return _get_child_helper( + parent, _get_field_descriptor(desc, field_name), field_name, backing_str_tensor + ) + + +def _get_needed_fields(destinations: Sequence[expression.Expression]) -> Set[StrStep]: + field_names = set() # type: Set[StrStep] + for destination in destinations: + if isinstance(destination, _AbstractProtoChildExpression): + field_names.add(destination.name_as_field) + return field_names diff --git a/struct2tensor/expression_impl/proto_test.py b/struct2tensor/expression_impl/proto_test.py index 88927fa..e6938a1 100644 --- a/struct2tensor/expression_impl/proto_test.py +++ b/struct2tensor/expression_impl/proto_test.py @@ -16,318 +16,346 @@ import tensorflow as tf from absl.testing import absltest, parameterized from tensorflow.python.framework import ( - test_util, # pylint: disable=g-direct-tensorflow-import + test_util, # pylint: disable=g-direct-tensorflow-import ) from struct2tensor import calculate_options, path from struct2tensor.expression_impl import proto, proto_test_util from struct2tensor.test import ( - expression_test_util, - test_any_pb2, - test_extension_pb2, - test_map_pb2, - test_pb2, + expression_test_util, + test_any_pb2, + test_extension_pb2, + test_map_pb2, + test_pb2, ) def _get_expression_with_any(): - my_any_0 = test_any_pb2.MessageWithAny() - my_value_0 = test_pb2.AllSimple() - my_value_0.optional_int32 = 0 - my_any_0.my_any.Pack(my_value_0) - my_any_1 = test_any_pb2.MessageWithAny() - my_value_1 = test_pb2.UserInfo() - my_any_1.my_any.Pack(my_value_1) - my_any_2 = test_any_pb2.MessageWithAny() - my_value_2 = test_pb2.AllSimple() - my_value_2.optional_int32 = 20 - my_any_2.my_any.Pack(my_value_2) - serialized = [x.SerializeToString() for x in [my_any_0, my_any_1, my_any_2]] - return proto.create_expression_from_proto( - serialized, test_any_pb2.MessageWithAny.DESCRIPTOR) + my_any_0 = test_any_pb2.MessageWithAny() + my_value_0 = test_pb2.AllSimple() + my_value_0.optional_int32 = 0 + my_any_0.my_any.Pack(my_value_0) + my_any_1 = test_any_pb2.MessageWithAny() + my_value_1 = test_pb2.UserInfo() + my_any_1.my_any.Pack(my_value_1) + my_any_2 = test_any_pb2.MessageWithAny() + my_value_2 = test_pb2.AllSimple() + my_value_2.optional_int32 = 20 + my_any_2.my_any.Pack(my_value_2) + serialized = [x.SerializeToString() for x in [my_any_0, my_any_1, my_any_2]] + return proto.create_expression_from_proto( + serialized, test_any_pb2.MessageWithAny.DESCRIPTOR + ) def _get_user_info_with_extension(): - my_user_info = test_pb2.UserInfo() - my_user_info.Extensions[ - test_extension_pb2.MyExternalExtension.ext].special = "shhh" + my_user_info = test_pb2.UserInfo() + my_user_info.Extensions[test_extension_pb2.MyExternalExtension.ext].special = "shhh" - serialized = [my_user_info.SerializeToString()] - return proto.create_expression_from_proto( - serialized, test_any_pb2.MessageWithAny.DESCRIPTOR) + serialized = [my_user_info.SerializeToString()] + return proto.create_expression_from_proto( + serialized, test_any_pb2.MessageWithAny.DESCRIPTOR + ) class ProtoTest(absltest.TestCase): - - def test_create_expression_from_proto_with_event(self): - expr = proto_test_util._get_expression_from_session_empty_user_info() - event_expr = expr.get_child_or_error("event") - self.assertTrue(event_expr.is_repeated) - self.assertIsNone(event_expr.type) - self.assertFalse(event_expr.is_leaf) - self.assertFalse(event_expr.calculation_is_identity()) - self.assertTrue(event_expr.calculation_equal(event_expr)) - self.assertFalse(event_expr.calculation_equal(expr)) - child_node = 
expression_test_util.calculate_value_slowly(event_expr) - self.assertEqual(child_node.parent_index.dtype, tf.int64) - self.assertEqual( - event_expr.known_field_names(), - frozenset({ - "event_id", "query", "query_token", "action", "user_info", - "action_mask" - })) - - sources = event_expr.get_source_expressions() - self.assertLen(sources, 1) - self.assertIs(expr, sources[0]) - - def test_create_expression_from_proto_with_root(self): - expr = proto_test_util._get_expression_from_session_empty_user_info() - self.assertTrue(expr.is_repeated) - self.assertIsNone(expr.type) - self.assertFalse(expr.is_leaf) - self.assertFalse(expr.calculation_is_identity()) - self.assertTrue(expr.calculation_equal(expr)) - self.assertFalse(expr.calculation_equal(expr.get_child_or_error("event"))) - root_node = expression_test_util.calculate_value_slowly(expr) - self.assertEqual(root_node.size.dtype, tf.int64) - self.assertEqual(expr.known_field_names(), - frozenset({"event", "session_id", "session_info"})) - - sources = expr.get_source_expressions() - self.assertEmpty(sources) - - def test_user_info_with_extension(self): - expr = _get_user_info_with_extension() - ext_expr = expr.get_child_or_error( - "(struct2tensor.test.MyExternalExtension.ext)") - self.assertFalse(ext_expr.is_repeated) - self.assertIsNone(ext_expr.type) - self.assertFalse(ext_expr.is_leaf) - self.assertFalse(ext_expr.calculation_is_identity()) - self.assertTrue(ext_expr.calculation_equal(ext_expr)) - self.assertFalse(ext_expr.calculation_equal(expr)) - child_node = expression_test_util.calculate_value_slowly(ext_expr) - self.assertEqual(child_node.parent_index.dtype, tf.int64) - self.assertEqual(ext_expr.known_field_names(), frozenset({"special"})) - - def test_missing_extension(self): - """Tests a missing extension on a deep tree.""" - expr = proto_test_util._get_expression_from_session_empty_user_info() - missing_expr = expr.get_child("(ext.NotPresent)") - self.assertIsNone(missing_expr) - - def test_create_expression_from_proto_with_any(self): - """Test an any field.""" - expr = _get_expression_with_any() - any_expr = expr.get_child_or_error("my_any") - simple_expr = expr.get_descendant_or_error( - path.Path( - ["my_any", "(type.googleapis.com/struct2tensor.test.AllSimple)"])) - self.assertFalse(simple_expr.is_repeated) - self.assertIsNone(simple_expr.type) - self.assertFalse(simple_expr.is_leaf) - self.assertFalse(simple_expr.calculation_is_identity()) - self.assertTrue(simple_expr.calculation_equal(simple_expr)) - self.assertFalse(simple_expr.calculation_equal(expr)) - child_node = expression_test_util.calculate_value_slowly(simple_expr) - self.assertEqual(child_node.parent_index.dtype, tf.int64) - self.assertEqual( - simple_expr.known_field_names(), - frozenset({ - "optional_string", "optional_uint64", "repeated_uint64", - "repeated_int32", "repeated_string", "optional_int32", - "optional_float", "repeated_int64", "optional_uint32", - "repeated_float", "repeated_uint32", "optional_double", - "optional_int64", "repeated_double" - })) - - sources = simple_expr.get_source_expressions() - self.assertLen(sources, 1) - self.assertIs(any_expr, sources[0]) - - def test_create_expression_from_proto_with_any_missing_message(self): - """Test an any field.""" - expr = _get_expression_with_any() - any_expr = expr.get_child_or_error("my_any") - simple_expr = expr.get_descendant_or_error( - path.Path([ - "my_any", "(type.googleapis.com/struct2tensor.test.SpecialUserInfo)" - ])) - self.assertFalse(simple_expr.is_repeated) - 
self.assertIsNone(simple_expr.type) - self.assertFalse(simple_expr.is_leaf) - self.assertFalse(simple_expr.calculation_is_identity()) - self.assertTrue(simple_expr.calculation_equal(simple_expr)) - self.assertFalse(simple_expr.calculation_equal(expr)) - child_node = expression_test_util.calculate_value_slowly(simple_expr) - self.assertEqual(child_node.parent_index.dtype, tf.int64) - self.assertEqual(simple_expr.known_field_names(), frozenset({"secret"})) - - sources = simple_expr.get_source_expressions() - self.assertLen(sources, 1) - self.assertIs(any_expr, sources[0]) - - def test_create_expression_from_proto_with_any_type_url(self): - """Test an any with type_url.""" - expr = _get_expression_with_any() - any_expr = expr.get_child_or_error("my_any") - simple_expr = expr.get_descendant_or_error( - path.Path(["my_any", "type_url"])) - self.assertFalse(simple_expr.is_repeated) - self.assertEqual(simple_expr.type, tf.string) - self.assertTrue(simple_expr.is_leaf) - self.assertFalse(simple_expr.calculation_is_identity()) - self.assertTrue(simple_expr.calculation_equal(simple_expr)) - self.assertFalse(simple_expr.calculation_equal(expr)) - leaf_node = expression_test_util.calculate_value_slowly(simple_expr) - self.assertEqual(leaf_node.parent_index.dtype, tf.int64) - self.assertEqual(leaf_node.values.dtype, tf.string) - self.assertEqual(simple_expr.known_field_names(), frozenset({})) - - sources = simple_expr.get_source_expressions() - self.assertLen(sources, 1) - self.assertIs(any_expr, sources[0]) - - def test_create_expression_from_proto_with_any_value(self): - """Test an any with value.""" - expr = _get_expression_with_any() - any_expr = expr.get_child_or_error("my_any") - simple_expr = expr.get_descendant_or_error(path.Path(["my_any", "value"])) - self.assertFalse(simple_expr.is_repeated) - self.assertEqual(simple_expr.type, tf.string) - self.assertTrue(simple_expr.is_leaf) - self.assertFalse(simple_expr.calculation_is_identity()) - self.assertTrue(simple_expr.calculation_equal(simple_expr)) - self.assertFalse(simple_expr.calculation_equal(expr)) - leaf_node = expression_test_util.calculate_value_slowly(simple_expr) - self.assertEqual(leaf_node.parent_index.dtype, tf.int64) - self.assertEqual(leaf_node.values.dtype, tf.string) - self.assertEqual(simple_expr.known_field_names(), frozenset({})) - - sources = simple_expr.get_source_expressions() - self.assertLen(sources, 1) - self.assertIs(any_expr, sources[0]) - - def test_create_transformed_field(self): - expr = proto_test_util._get_expression_from_session_empty_user_info() - reversed_events_expr = proto.create_transformed_field( - expr, path.Path(["event"]), "reversed_event", _reverse_values) - source_events = expr.get_child_or_error("event") - dest_events = reversed_events_expr.get_child_or_error("reversed_event") - self.assertTrue(dest_events.is_repeated) - self.assertFalse(dest_events.is_leaf) - self.assertEqual(source_events.type, dest_events.type) - leaf_expr = reversed_events_expr.get_descendant_or_error( - path.Path(["reversed_event", "action", "doc_id"])) - leaf_tensor = expression_test_util.calculate_value_slowly(leaf_expr) - self.assertEqual(leaf_tensor.parent_index.dtype, tf.int64) - self.assertEqual(leaf_tensor.values.dtype, tf.string) - - def test_create_reversed_field_nested(self): - expr = proto_test_util._get_expression_from_session_empty_user_info() - first_reverse = proto.create_transformed_field(expr, path.Path(["event"]), - "reversed_event", - _reverse_values) - second_reverse = proto.create_transformed_field( - 
first_reverse, path.Path(["reversed_event", "action"]), - "reversed_action", _reverse_values) - leaf_expr = second_reverse.get_descendant_or_error( - path.Path(["reversed_event", "reversed_action", "doc_id"])) - leaf_tensor = expression_test_util.calculate_value_slowly(leaf_expr) - self.assertEqual(leaf_tensor.parent_index.dtype, tf.int64) - self.assertEqual(leaf_tensor.values.dtype, tf.string) + def test_create_expression_from_proto_with_event(self): + expr = proto_test_util._get_expression_from_session_empty_user_info() + event_expr = expr.get_child_or_error("event") + self.assertTrue(event_expr.is_repeated) + self.assertIsNone(event_expr.type) + self.assertFalse(event_expr.is_leaf) + self.assertFalse(event_expr.calculation_is_identity()) + self.assertTrue(event_expr.calculation_equal(event_expr)) + self.assertFalse(event_expr.calculation_equal(expr)) + child_node = expression_test_util.calculate_value_slowly(event_expr) + self.assertEqual(child_node.parent_index.dtype, tf.int64) + self.assertEqual( + event_expr.known_field_names(), + frozenset( + { + "event_id", + "query", + "query_token", + "action", + "user_info", + "action_mask", + } + ), + ) + + sources = event_expr.get_source_expressions() + self.assertLen(sources, 1) + self.assertIs(expr, sources[0]) + + def test_create_expression_from_proto_with_root(self): + expr = proto_test_util._get_expression_from_session_empty_user_info() + self.assertTrue(expr.is_repeated) + self.assertIsNone(expr.type) + self.assertFalse(expr.is_leaf) + self.assertFalse(expr.calculation_is_identity()) + self.assertTrue(expr.calculation_equal(expr)) + self.assertFalse(expr.calculation_equal(expr.get_child_or_error("event"))) + root_node = expression_test_util.calculate_value_slowly(expr) + self.assertEqual(root_node.size.dtype, tf.int64) + self.assertEqual( + expr.known_field_names(), frozenset({"event", "session_id", "session_info"}) + ) + + sources = expr.get_source_expressions() + self.assertEmpty(sources) + + def test_user_info_with_extension(self): + expr = _get_user_info_with_extension() + ext_expr = expr.get_child_or_error( + "(struct2tensor.test.MyExternalExtension.ext)" + ) + self.assertFalse(ext_expr.is_repeated) + self.assertIsNone(ext_expr.type) + self.assertFalse(ext_expr.is_leaf) + self.assertFalse(ext_expr.calculation_is_identity()) + self.assertTrue(ext_expr.calculation_equal(ext_expr)) + self.assertFalse(ext_expr.calculation_equal(expr)) + child_node = expression_test_util.calculate_value_slowly(ext_expr) + self.assertEqual(child_node.parent_index.dtype, tf.int64) + self.assertEqual(ext_expr.known_field_names(), frozenset({"special"})) + + def test_missing_extension(self): + """Tests a missing extension on a deep tree.""" + expr = proto_test_util._get_expression_from_session_empty_user_info() + missing_expr = expr.get_child("(ext.NotPresent)") + self.assertIsNone(missing_expr) + + def test_create_expression_from_proto_with_any(self): + """Test an any field.""" + expr = _get_expression_with_any() + any_expr = expr.get_child_or_error("my_any") + simple_expr = expr.get_descendant_or_error( + path.Path(["my_any", "(type.googleapis.com/struct2tensor.test.AllSimple)"]) + ) + self.assertFalse(simple_expr.is_repeated) + self.assertIsNone(simple_expr.type) + self.assertFalse(simple_expr.is_leaf) + self.assertFalse(simple_expr.calculation_is_identity()) + self.assertTrue(simple_expr.calculation_equal(simple_expr)) + self.assertFalse(simple_expr.calculation_equal(expr)) + child_node = expression_test_util.calculate_value_slowly(simple_expr) + 
self.assertEqual(child_node.parent_index.dtype, tf.int64) + self.assertEqual( + simple_expr.known_field_names(), + frozenset( + { + "optional_string", + "optional_uint64", + "repeated_uint64", + "repeated_int32", + "repeated_string", + "optional_int32", + "optional_float", + "repeated_int64", + "optional_uint32", + "repeated_float", + "repeated_uint32", + "optional_double", + "optional_int64", + "repeated_double", + } + ), + ) + + sources = simple_expr.get_source_expressions() + self.assertLen(sources, 1) + self.assertIs(any_expr, sources[0]) + + def test_create_expression_from_proto_with_any_missing_message(self): + """Test an any field.""" + expr = _get_expression_with_any() + any_expr = expr.get_child_or_error("my_any") + simple_expr = expr.get_descendant_or_error( + path.Path( + ["my_any", "(type.googleapis.com/struct2tensor.test.SpecialUserInfo)"] + ) + ) + self.assertFalse(simple_expr.is_repeated) + self.assertIsNone(simple_expr.type) + self.assertFalse(simple_expr.is_leaf) + self.assertFalse(simple_expr.calculation_is_identity()) + self.assertTrue(simple_expr.calculation_equal(simple_expr)) + self.assertFalse(simple_expr.calculation_equal(expr)) + child_node = expression_test_util.calculate_value_slowly(simple_expr) + self.assertEqual(child_node.parent_index.dtype, tf.int64) + self.assertEqual(simple_expr.known_field_names(), frozenset({"secret"})) + + sources = simple_expr.get_source_expressions() + self.assertLen(sources, 1) + self.assertIs(any_expr, sources[0]) + + def test_create_expression_from_proto_with_any_type_url(self): + """Test an any with type_url.""" + expr = _get_expression_with_any() + any_expr = expr.get_child_or_error("my_any") + simple_expr = expr.get_descendant_or_error(path.Path(["my_any", "type_url"])) + self.assertFalse(simple_expr.is_repeated) + self.assertEqual(simple_expr.type, tf.string) + self.assertTrue(simple_expr.is_leaf) + self.assertFalse(simple_expr.calculation_is_identity()) + self.assertTrue(simple_expr.calculation_equal(simple_expr)) + self.assertFalse(simple_expr.calculation_equal(expr)) + leaf_node = expression_test_util.calculate_value_slowly(simple_expr) + self.assertEqual(leaf_node.parent_index.dtype, tf.int64) + self.assertEqual(leaf_node.values.dtype, tf.string) + self.assertEqual(simple_expr.known_field_names(), frozenset({})) + + sources = simple_expr.get_source_expressions() + self.assertLen(sources, 1) + self.assertIs(any_expr, sources[0]) + + def test_create_expression_from_proto_with_any_value(self): + """Test an any with value.""" + expr = _get_expression_with_any() + any_expr = expr.get_child_or_error("my_any") + simple_expr = expr.get_descendant_or_error(path.Path(["my_any", "value"])) + self.assertFalse(simple_expr.is_repeated) + self.assertEqual(simple_expr.type, tf.string) + self.assertTrue(simple_expr.is_leaf) + self.assertFalse(simple_expr.calculation_is_identity()) + self.assertTrue(simple_expr.calculation_equal(simple_expr)) + self.assertFalse(simple_expr.calculation_equal(expr)) + leaf_node = expression_test_util.calculate_value_slowly(simple_expr) + self.assertEqual(leaf_node.parent_index.dtype, tf.int64) + self.assertEqual(leaf_node.values.dtype, tf.string) + self.assertEqual(simple_expr.known_field_names(), frozenset({})) + + sources = simple_expr.get_source_expressions() + self.assertLen(sources, 1) + self.assertIs(any_expr, sources[0]) + + def test_create_transformed_field(self): + expr = proto_test_util._get_expression_from_session_empty_user_info() + reversed_events_expr = proto.create_transformed_field( + expr, 
path.Path(["event"]), "reversed_event", _reverse_values + ) + source_events = expr.get_child_or_error("event") + dest_events = reversed_events_expr.get_child_or_error("reversed_event") + self.assertTrue(dest_events.is_repeated) + self.assertFalse(dest_events.is_leaf) + self.assertEqual(source_events.type, dest_events.type) + leaf_expr = reversed_events_expr.get_descendant_or_error( + path.Path(["reversed_event", "action", "doc_id"]) + ) + leaf_tensor = expression_test_util.calculate_value_slowly(leaf_expr) + self.assertEqual(leaf_tensor.parent_index.dtype, tf.int64) + self.assertEqual(leaf_tensor.values.dtype, tf.string) + + def test_create_reversed_field_nested(self): + expr = proto_test_util._get_expression_from_session_empty_user_info() + first_reverse = proto.create_transformed_field( + expr, path.Path(["event"]), "reversed_event", _reverse_values + ) + second_reverse = proto.create_transformed_field( + first_reverse, + path.Path(["reversed_event", "action"]), + "reversed_action", + _reverse_values, + ) + leaf_expr = second_reverse.get_descendant_or_error( + path.Path(["reversed_event", "reversed_action", "doc_id"]) + ) + leaf_tensor = expression_test_util.calculate_value_slowly(leaf_expr) + self.assertEqual(leaf_tensor.parent_index.dtype, tf.int64) + self.assertEqual(leaf_tensor.values.dtype, tf.string) @test_util.run_all_in_graph_and_eager_modes class ProtoValuesTest(tf.test.TestCase, parameterized.TestCase): - - def _get_calculate_options(self, use_string_view): - options = calculate_options.get_default_options() - options.use_string_view = use_string_view - return options - - def _check_string_view(self): - for op in tf.compat.v1.get_default_graph().get_operations(): - if op.type.startswith("DecodeProtoSparse"): - self.assertLen(op.inputs, 2) - if op.type.startswith("DecodeProtoMap"): - self.assertLen(op.inputs, 3) - - @parameterized.named_parameters(("string_view", True), - ("no_string_view", False)) - def test_create_expression_from_proto_and_calculate_root_value( - self, use_string_view): - """Tests get_sparse_tensors on a deep tree.""" - expr = proto_test_util._get_expression_from_session_empty_user_info() - root_value = expression_test_util.calculate_value_slowly( - expr, options=self._get_calculate_options(use_string_view)) - # For some reason, this fails on tf.eager. It could be because it is - # a scalar, I don't know. 
- self.assertEqual(self.evaluate(root_value.size), 2) - if use_string_view: - self._check_string_view() - - @parameterized.named_parameters(("string_view", True), - ("no_string_view", False)) - def test_create_expression_from_proto_and_calculate_event_value( - self, use_string_view): - """Tests get_sparse_tensors on a deep tree.""" - expr = proto_test_util._get_expression_from_session_empty_user_info() - event_value = expression_test_util.calculate_value_slowly( - expr.get_child_or_error("event"), - options=self._get_calculate_options(use_string_view)) - self.assertAllEqual(event_value.parent_index, [0, 0, 0, 1, 1]) - if use_string_view: - self._check_string_view() - - @parameterized.named_parameters(("string_view", True), - ("no_string_view", False)) - def test_create_expression_from_proto_and_calculate_event_id_value( - self, use_string_view): - """Tests get_sparse_tensors on a deep tree.""" - expr = proto_test_util._get_expression_from_session_empty_user_info() - event_id_value = expression_test_util.calculate_value_slowly( - expr.get_descendant_or_error(path.Path(["event", "event_id"])), - options=self._get_calculate_options(use_string_view)) - self.assertAllEqual(event_id_value.parent_index, [0, 1, 2, 4]) - self.assertAllEqual(event_id_value.values, [b"A", b"B", b"C", b"D"]) - if use_string_view: - self._check_string_view() - - @parameterized.named_parameters(("string_view", True), - ("no_string_view", False)) - def test_create_expression_from_proto_with_any(self, use_string_view): - """Test an any field.""" - expr = _get_expression_with_any() - simple_expr = expr.get_descendant_or_error( - path.Path( - ["my_any", "(type.googleapis.com/struct2tensor.test.AllSimple)"])) - child_node = expression_test_util.calculate_value_slowly( - simple_expr, - options=self._get_calculate_options(use_string_view)).parent_index - self.assertAllEqual(child_node, [0, 2]) - - @parameterized.named_parameters(("string_view", True), - ("no_string_view", False)) - def test_create_expression_from_proto_with_any_missing_message( - self, use_string_view): - """Test an any field with a message that is absent.""" - expr = _get_expression_with_any() - simple_expr = expr.get_descendant_or_error( - path.Path([ - "my_any", "(type.googleapis.com/struct2tensor.test.SpecialUserInfo)" - ])) - child_node = expression_test_util.calculate_value_slowly( - simple_expr, - options=self._get_calculate_options(use_string_view)).parent_index - self.assertAllEqual(child_node, []) - - @parameterized.named_parameters(("string_view", True), - ("no_string_view", False)) - def test_project_proto_map(self, use_string_view): - examples = [ - """ + def _get_calculate_options(self, use_string_view): + options = calculate_options.get_default_options() + options.use_string_view = use_string_view + return options + + def _check_string_view(self): + for op in tf.compat.v1.get_default_graph().get_operations(): + if op.type.startswith("DecodeProtoSparse"): + self.assertLen(op.inputs, 2) + if op.type.startswith("DecodeProtoMap"): + self.assertLen(op.inputs, 3) + + @parameterized.named_parameters(("string_view", True), ("no_string_view", False)) + def test_create_expression_from_proto_and_calculate_root_value( + self, use_string_view + ): + """Tests get_sparse_tensors on a deep tree.""" + expr = proto_test_util._get_expression_from_session_empty_user_info() + root_value = expression_test_util.calculate_value_slowly( + expr, options=self._get_calculate_options(use_string_view) + ) + # For some reason, this fails on tf.eager. 
It could be because it is + # a scalar, I don't know. + self.assertEqual(self.evaluate(root_value.size), 2) + if use_string_view: + self._check_string_view() + + @parameterized.named_parameters(("string_view", True), ("no_string_view", False)) + def test_create_expression_from_proto_and_calculate_event_value( + self, use_string_view + ): + """Tests get_sparse_tensors on a deep tree.""" + expr = proto_test_util._get_expression_from_session_empty_user_info() + event_value = expression_test_util.calculate_value_slowly( + expr.get_child_or_error("event"), + options=self._get_calculate_options(use_string_view), + ) + self.assertAllEqual(event_value.parent_index, [0, 0, 0, 1, 1]) + if use_string_view: + self._check_string_view() + + @parameterized.named_parameters(("string_view", True), ("no_string_view", False)) + def test_create_expression_from_proto_and_calculate_event_id_value( + self, use_string_view + ): + """Tests get_sparse_tensors on a deep tree.""" + expr = proto_test_util._get_expression_from_session_empty_user_info() + event_id_value = expression_test_util.calculate_value_slowly( + expr.get_descendant_or_error(path.Path(["event", "event_id"])), + options=self._get_calculate_options(use_string_view), + ) + self.assertAllEqual(event_id_value.parent_index, [0, 1, 2, 4]) + self.assertAllEqual(event_id_value.values, [b"A", b"B", b"C", b"D"]) + if use_string_view: + self._check_string_view() + + @parameterized.named_parameters(("string_view", True), ("no_string_view", False)) + def test_create_expression_from_proto_with_any(self, use_string_view): + """Test an any field.""" + expr = _get_expression_with_any() + simple_expr = expr.get_descendant_or_error( + path.Path(["my_any", "(type.googleapis.com/struct2tensor.test.AllSimple)"]) + ) + child_node = expression_test_util.calculate_value_slowly( + simple_expr, options=self._get_calculate_options(use_string_view) + ).parent_index + self.assertAllEqual(child_node, [0, 2]) + + @parameterized.named_parameters(("string_view", True), ("no_string_view", False)) + def test_create_expression_from_proto_with_any_missing_message( + self, use_string_view + ): + """Test an any field with a message that is absent.""" + expr = _get_expression_with_any() + simple_expr = expr.get_descendant_or_error( + path.Path( + ["my_any", "(type.googleapis.com/struct2tensor.test.SpecialUserInfo)"] + ) + ) + child_node = expression_test_util.calculate_value_slowly( + simple_expr, options=self._get_calculate_options(use_string_view) + ).parent_index + self.assertAllEqual(child_node, []) + + @parameterized.named_parameters(("string_view", True), ("no_string_view", False)) + def test_project_proto_map(self, use_string_view): + examples = [ + """ features { feature { key: "feature1" @@ -338,7 +366,8 @@ def test_project_proto_map(self, use_string_view): value { float_list { value: 8.0 } } } } - """, """ + """, + """ features { feature { key: "feature1" @@ -349,119 +378,141 @@ def test_project_proto_map(self, use_string_view): value { int64_list { value: [123, 456] } } } } - """ - ] - expr = proto_test_util.text_to_expression(examples, tf.train.Example) - result = expression_test_util.calculate_list_map( - expr.project([ - "features.feature[feature1].bytes_list.value", - "features.feature[feature2].float_list.value", - "features.feature[feature3].int64_list.value", - ]), - self, - options=self._get_calculate_options(use_string_view)) - - feature1 = result["features.feature[feature1].bytes_list.value"] - feature2 = result["features.feature[feature2].float_list.value"] - feature3 
= result["features.feature[feature3].int64_list.value"] - self.assertAllEqual(feature1, - [[[[[b"hello", b"world"]]]], [[[[b"deadbeef"]]]]]) - self.assertAllEqual(feature2, [[[[[8.0]]]], [[]]]) - self.assertAllEqual(feature3, [[[]], [[[[123, 456]]]]]) - if use_string_view: - self._check_string_view() - - @parameterized.named_parameters(("string_view", True), - ("no_string_view", False)) - def test_project_proto_map_leaf_value(self, use_string_view): - protos = [ - """ + """, + ] + expr = proto_test_util.text_to_expression(examples, tf.train.Example) + result = expression_test_util.calculate_list_map( + expr.project( + [ + "features.feature[feature1].bytes_list.value", + "features.feature[feature2].float_list.value", + "features.feature[feature3].int64_list.value", + ] + ), + self, + options=self._get_calculate_options(use_string_view), + ) + + feature1 = result["features.feature[feature1].bytes_list.value"] + feature2 = result["features.feature[feature2].float_list.value"] + feature3 = result["features.feature[feature3].int64_list.value"] + self.assertAllEqual(feature1, [[[[[b"hello", b"world"]]]], [[[[b"deadbeef"]]]]]) + self.assertAllEqual(feature2, [[[[[8.0]]]], [[]]]) + self.assertAllEqual(feature3, [[[]], [[[[123, 456]]]]]) + if use_string_view: + self._check_string_view() + + @parameterized.named_parameters(("string_view", True), ("no_string_view", False)) + def test_project_proto_map_leaf_value(self, use_string_view): + protos = [ + """ int32_string_map { key: 222 value: "2" } """ - ] - - expr = proto_test_util.text_to_expression(protos, - test_map_pb2.MessageWithMap) - result = expression_test_util.calculate_list_map( - expr.project([ - "int32_string_map[222]", - "int32_string_map[223]", - ]), - self, - options=self._get_calculate_options(use_string_view)) - self.assertLen(result, 2) - self.assertAllEqual(result["int32_string_map[222]"], [[b"2"]]) - self.assertAllEqual(result["int32_string_map[223]"], [[]]) - if use_string_view: - self._check_string_view() - - @parameterized.named_parameters(("string_view", True), - ("no_string_view", False)) - def test_transformed_field_values(self, use_string_view): - expr = proto_test_util._get_expression_from_session_empty_user_info() - reversed_events_expr = proto.create_transformed_field( - expr, path.Path(["event"]), "reversed_event", _reverse_values) - result = expression_test_util.calculate_list_map( - reversed_events_expr.project(["reversed_event.action.doc_id"]), - self, - options=self._get_calculate_options(use_string_view)) - self.assertAllEqual(result["reversed_event.action.doc_id"], - [[[[b"h"], [b"i"], [b"j"]], [[b"g"]], [[b"e"], [b"f"]]], - [[[b"c"], []], [[b"a"], [b"b"]]]]) - if use_string_view: - self._check_string_view() - - @parameterized.named_parameters(("string_view", True), - ("no_string_view", False)) - def test_transformed_field_values_with_transformed_parent( - self, use_string_view): - expr = proto_test_util._get_expression_from_session_empty_user_info() - first_reversed_expr = proto.create_transformed_field( - expr, path.Path(["event"]), "reversed_event", _reverse_values) - second_reversed_expr = proto.create_transformed_field( - first_reversed_expr, path.Path(["reversed_event", "action"]), - "reversed_action", _reverse_values) - result = expression_test_util.calculate_list_map( - second_reversed_expr.project(["reversed_event.reversed_action.doc_id"]), - self, - options=self._get_calculate_options(use_string_view)) - self.assertAllEqual(result["reversed_event.reversed_action.doc_id"], - [[[[b"b"], [b"a"], []], 
[[b"c"]], [[b"f"], [b"e"]]], - [[[b"g"], [b"j"]], [[b"i"], [b"h"]]]]) - if use_string_view: - self._check_string_view() - - @parameterized.named_parameters(("string_view", True), - ("no_string_view", False)) - def test_transformed_field_values_with_multiple_transforms( - self, use_string_view): - expr = proto_test_util._get_expression_from_session_empty_user_info() - reversed_events_expr = proto.create_transformed_field( - expr, path.Path(["event"]), "reversed_event", _reverse_values) - reversed_events_again_expr = proto.create_transformed_field( - reversed_events_expr, path.Path(["reversed_event"]), - "reversed_reversed_event", _reverse_values) - - result = expression_test_util.calculate_list_map( - reversed_events_again_expr.project( - ["reversed_reversed_event.action.doc_id"]), - self, - options=self._get_calculate_options(use_string_view)) - self.assertAllEqual(result["reversed_reversed_event.action.doc_id"], - [[[[b"a"], [b"b"]], [[b"c"], []], [[b"e"], [b"f"]]], - [[[b"g"]], [[b"h"], [b"i"], [b"j"]]]]) - if use_string_view: - self._check_string_view() - + ] + + expr = proto_test_util.text_to_expression(protos, test_map_pb2.MessageWithMap) + result = expression_test_util.calculate_list_map( + expr.project( + [ + "int32_string_map[222]", + "int32_string_map[223]", + ] + ), + self, + options=self._get_calculate_options(use_string_view), + ) + self.assertLen(result, 2) + self.assertAllEqual(result["int32_string_map[222]"], [[b"2"]]) + self.assertAllEqual(result["int32_string_map[223]"], [[]]) + if use_string_view: + self._check_string_view() + + @parameterized.named_parameters(("string_view", True), ("no_string_view", False)) + def test_transformed_field_values(self, use_string_view): + expr = proto_test_util._get_expression_from_session_empty_user_info() + reversed_events_expr = proto.create_transformed_field( + expr, path.Path(["event"]), "reversed_event", _reverse_values + ) + result = expression_test_util.calculate_list_map( + reversed_events_expr.project(["reversed_event.action.doc_id"]), + self, + options=self._get_calculate_options(use_string_view), + ) + self.assertAllEqual( + result["reversed_event.action.doc_id"], + [ + [[[b"h"], [b"i"], [b"j"]], [[b"g"]], [[b"e"], [b"f"]]], + [[[b"c"], []], [[b"a"], [b"b"]]], + ], + ) + if use_string_view: + self._check_string_view() + + @parameterized.named_parameters(("string_view", True), ("no_string_view", False)) + def test_transformed_field_values_with_transformed_parent(self, use_string_view): + expr = proto_test_util._get_expression_from_session_empty_user_info() + first_reversed_expr = proto.create_transformed_field( + expr, path.Path(["event"]), "reversed_event", _reverse_values + ) + second_reversed_expr = proto.create_transformed_field( + first_reversed_expr, + path.Path(["reversed_event", "action"]), + "reversed_action", + _reverse_values, + ) + result = expression_test_util.calculate_list_map( + second_reversed_expr.project(["reversed_event.reversed_action.doc_id"]), + self, + options=self._get_calculate_options(use_string_view), + ) + self.assertAllEqual( + result["reversed_event.reversed_action.doc_id"], + [ + [[[b"b"], [b"a"], []], [[b"c"]], [[b"f"], [b"e"]]], + [[[b"g"], [b"j"]], [[b"i"], [b"h"]]], + ], + ) + if use_string_view: + self._check_string_view() + + @parameterized.named_parameters(("string_view", True), ("no_string_view", False)) + def test_transformed_field_values_with_multiple_transforms(self, use_string_view): + expr = proto_test_util._get_expression_from_session_empty_user_info() + reversed_events_expr = 
proto.create_transformed_field( + expr, path.Path(["event"]), "reversed_event", _reverse_values + ) + reversed_events_again_expr = proto.create_transformed_field( + reversed_events_expr, + path.Path(["reversed_event"]), + "reversed_reversed_event", + _reverse_values, + ) + + result = expression_test_util.calculate_list_map( + reversed_events_again_expr.project( + ["reversed_reversed_event.action.doc_id"] + ), + self, + options=self._get_calculate_options(use_string_view), + ) + self.assertAllEqual( + result["reversed_reversed_event.action.doc_id"], + [ + [[[b"a"], [b"b"]], [[b"c"], []], [[b"e"], [b"f"]]], + [[[b"g"]], [[b"h"], [b"i"], [b"j"]]], + ], + ) + if use_string_view: + self._check_string_view() def _reverse_values(parent_indices, values): - """A simple function for testing create_transformed_field.""" - return parent_indices, tf.reverse(values, axis=[-1]) + """A simple function for testing create_transformed_field.""" + return parent_indices, tf.reverse(values, axis=[-1]) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_impl/proto_test_util.py b/struct2tensor/expression_impl/proto_test_util.py index 46d7795..80f21ef 100644 --- a/struct2tensor/expression_impl/proto_test_util.py +++ b/struct2tensor/expression_impl/proto_test_util.py @@ -21,20 +21,20 @@ def text_to_tensor(text_list, example_proto_clz): - as_protos = [text_format.Parse(x, example_proto_clz()) for x in text_list] - serialized = [x.SerializeToString() for x in as_protos] - return tf.constant(serialized) + as_protos = [text_format.Parse(x, example_proto_clz()) for x in text_list] + serialized = [x.SerializeToString() for x in as_protos] + return tf.constant(serialized) def text_to_expression(text_list, example_proto_clz): - """Create an expression from a list of text format protos.""" - return proto.create_expression_from_proto( - text_to_tensor(text_list, example_proto_clz), - example_proto_clz().DESCRIPTOR) + """Create an expression from a list of text format protos.""" + return proto.create_expression_from_proto( + text_to_tensor(text_list, example_proto_clz), example_proto_clz().DESCRIPTOR + ) def _get_expression_from_session_empty_user_info(): - r"""Run create_root_prensor on a very deep tree. + r"""Run create_root_prensor on a very deep tree. In addition, the user_info is empty. @@ -54,8 +54,9 @@ def _get_expression_from_session_empty_user_info(): Returns: A RootPrensor with the above structure. 
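text_to_expression above is the workhorse for these fixtures; a hedged usage sketch with tf.train.Example (the textproto payload is illustrative only):

import tensorflow as tf

from struct2tensor.expression_impl import proto_test_util

expr = proto_test_util.text_to_expression(
    ['features { feature { key: "f" value { int64_list { value: 1 } } } }'],
    tf.train.Example)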
""" - return text_to_expression([ - """ + return text_to_expression( + [ + """ event:{ action:{ doc_id:"a" @@ -81,7 +82,8 @@ def _get_expression_from_session_empty_user_info(): doc_id:"f" } event_id: "C" - }""", """ + }""", + """ event:{ action:{ doc_id:"g" @@ -98,5 +100,7 @@ def _get_expression_from_session_empty_user_info(): action:{ doc_id:"j" } - }""" - ], test_pb2.Session) + }""", + ], + test_pb2.Session, + ) diff --git a/struct2tensor/expression_impl/raw_parquet_dataset_test.py b/struct2tensor/expression_impl/raw_parquet_dataset_test.py index 66e3e7b..58481dc 100644 --- a/struct2tensor/expression_impl/raw_parquet_dataset_test.py +++ b/struct2tensor/expression_impl/raw_parquet_dataset_test.py @@ -16,1012 +16,1139 @@ import tensorflow as tf from absl.testing import absltest from tensorflow.python.framework import ( - test_util, # pylint: disable=g-direct-tensorflow-import + test_util, # pylint: disable=g-direct-tensorflow-import ) from struct2tensor.expression_impl import parquet class ParquetDatasetTestBase(tf.test.TestCase): - """Test base for testing _RawParquetDataset. - - This wraps tensorflow dataset's iterator and get_next, so that it can run both - eagerly and in graph mode. - """ - - def setUp(self): - super().setUp() - self._test_filenames = [ - "struct2tensor/testdata/parquet_testdata/dremel_example.parquet" - ] - self._datatype_test_filenames = [ - "struct2tensor/testdata/parquet_testdata/all_data_types.parquet" - ] - self._rowgroup_test_filenames = [ - "struct2tensor/testdata/parquet_testdata/dremel_example_two_row_groups.parquet" - ] - - def _assertTensorsEqual(self, result, expected, assert_items_equal): - """Checks that two list of prensors (collection of tensors) are equal. - - Args: - result: first list of prensors - expected: second list of prensors - assert_items_equal: determines if order of result and expected matters - """ - if assert_items_equal: - self.assertCountEqual(result, expected) - return - - for pren1, pren2 in zip(result, expected): - self.assertEqual(len(pren1), len(pren2)) - # we know the first ele is always scalar - self.assertEqual(pren1[0], pren2[0]) - for (tensor1, tensor2) in zip(pren1, pren2): - self.assertAllEqual(tensor1, tensor2) - - def _getNext(self, dataset, requires_initialization=False): - """Returns a callable that returns the next element of the dataset. - - Example use: - ```python - # In both graph and eager modes - dataset = ... - get_next = self._getNext(dataset) - result = self.evaluate(get_next()) - ``` - - Args: - dataset: A dataset whose elements will be returned. - requires_initialization: Indicates that when the test is executed in graph - mode, it should use an initializable iterator to iterate through the - dataset (e.g. when it contains stateful nodes). Defaults to False. - - Returns: - A callable that returns the next element of `dataset`. Any `TensorArray` - objects `dataset` outputs are stacked. - """ + """Test base for testing _RawParquetDataset. 
- def ta_wrapper(gn): - - def _wrapper(): - r = gn() - if isinstance(r, tf.TensorArray): - return r.melt() - return r - - return _wrapper - - building_function = tf.compat.v1.get_default_graph().building_function - if tf.executing_eagerly() or building_function: - iterator = iter(dataset) - return ta_wrapper(iterator._next_internal) - if requires_initialization: - iterator = tf.compat.v1.data.make_initializable_iterator(dataset) - self.evaluate(iterator.initializer) - else: - iterator = tf.compat.v1.data.make_one_shot_iterator(dataset) - get_next = iterator.get_next() - return ta_wrapper(lambda: get_next) - - def assertDatasetProduces(self, - dataset, - expected_output=None, - expected_shapes=None, - expected_error=None, - requires_initialization=False, - num_test_iterations=1, - assert_items_equal=False, - expected_error_iter=1): - """Asserts that a dataset produces the expected output / error. - - Args: - dataset: A dataset to check for the expected output / error. - expected_output: A list of elements that the dataset is expected to - produce. - expected_shapes: A list of TensorShapes which is expected to match - output_shapes of dataset. - expected_error: A tuple `(type, message)` identifying the expected error - `dataset` should raise. The `type` should match the expected exception - type, while `message` should either be a regular expression that is - expected to match the error message partially. - requires_initialization: Indicates that when the test is executed in graph - mode, it should use an initializable iterator to iterate through the - dataset (e.g. when it contains stateful nodes). Defaults to False. - num_test_iterations: Number of times `dataset` will be iterated. Defaults - to 2. - assert_items_equal: Tests expected_output has (only) the same elements - regardless of order. - expected_error_iter: How many times to iterate before expecting an error, - if an error is expected. + This wraps tensorflow dataset's iterator and get_next, so that it can run both + eagerly and in graph mode. """ - self.assertTrue( - expected_error is not None or expected_output is not None, - "Exactly one of expected_output or expected error should be provided.") - if expected_error: - self.assertIsNone( - expected_output, - "Exactly one of expected_output or expected error should be provided." 
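The pattern the base class enables, sketched against the dataset under test in this file (constructor arguments as used by the tests below; the expected tuple is the one `testFirstParentIndexOnly`-style tests assert on):

```python
pq_ds = parquet._RawParquetDataset(
    filenames=["struct2tensor/testdata/parquet_testdata/dremel_example.parquet"],
    value_paths=["DocId"],
    value_dtypes=(tf.int64,),
    parent_index_paths=["DocId"],
    path_index=[0],
    batch_size=1,
)
# _getNext hides the eager/graph iterator differences from the test body.
get_next = self._getNext(pq_ds)
# Each element is a tuple: (batch size, parent indices..., values...),
# e.g. (1, [0], [10]) for the first DocId message in the file.
first = self.evaluate(get_next())
```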
-      )
-      with self.assertRaisesRegex(expected_error[0], expected_error[1]):
-        get_next = self._getNext(
-            dataset, requires_initialization=requires_initialization)
-        for _ in range(expected_error_iter):
-          self.evaluate(get_next())
-      return
-    if expected_shapes:
-      self.assertEqual(expected_shapes, tf.data.get_output_shapes(dataset))
-    self.assertGreater(num_test_iterations, 0)
-    for _ in range(num_test_iterations):
-      get_next = self._getNext(
-          dataset, requires_initialization=requires_initialization)
-      result = []
-      for _ in range(len(expected_output)):
-        res = self.evaluate(get_next())
-        result.append(res)
-      self._assertTensorsEqual(result, expected_output, assert_items_equal)
-      with self.assertRaises(tf.errors.OutOfRangeError):
-        self.evaluate(get_next())
-      with self.assertRaises(tf.errors.OutOfRangeError):
-        self.evaluate(get_next())
+
+    def setUp(self):
+        super().setUp()
+        self._test_filenames = [
+            "struct2tensor/testdata/parquet_testdata/dremel_example.parquet"
+        ]
+        self._datatype_test_filenames = [
+            "struct2tensor/testdata/parquet_testdata/all_data_types.parquet"
+        ]
+        self._rowgroup_test_filenames = [
+            "struct2tensor/testdata/parquet_testdata/dremel_example_two_row_groups.parquet"
+        ]
+
+    def _assertTensorsEqual(self, result, expected, assert_items_equal):
+        """Checks that two lists of prensors (collections of tensors) are equal.
+
+        Args:
+          result: first list of prensors
+          expected: second list of prensors
+          assert_items_equal: determines if order of result and expected matters
+        """
+        if assert_items_equal:
+            self.assertCountEqual(result, expected)
+            return
+
+        for pren1, pren2 in zip(result, expected):
+            self.assertEqual(len(pren1), len(pren2))
+            # we know the first element is always a scalar
+            self.assertEqual(pren1[0], pren2[0])
+            for tensor1, tensor2 in zip(pren1, pren2):
+                self.assertAllEqual(tensor1, tensor2)
+
+    def _getNext(self, dataset, requires_initialization=False):
+        """Returns a callable that returns the next element of the dataset.
+
+        Example use:
+        ```python
+        # In both graph and eager modes
+        dataset = ...
+        get_next = self._getNext(dataset)
+        result = self.evaluate(get_next())
+        ```
+
+        Args:
+          dataset: A dataset whose elements will be returned.
+          requires_initialization: Indicates that when the test is executed in graph
+            mode, it should use an initializable iterator to iterate through the
+            dataset (e.g. when it contains stateful nodes). Defaults to False.
+
+        Returns:
+          A callable that returns the next element of `dataset`. Any `TensorArray`
+          objects `dataset` outputs are stacked.
+        """
+
+        def ta_wrapper(gn):
+            def _wrapper():
+                r = gn()
+                if isinstance(r, tf.TensorArray):
+                    # Stack the TensorArray into a dense tensor, as documented above.
+                    return r.stack()
+                return r
+
+            return _wrapper
+
+        building_function = tf.compat.v1.get_default_graph().building_function
+        if tf.executing_eagerly() or building_function:
+            iterator = iter(dataset)
+            return ta_wrapper(iterator._next_internal)
+        if requires_initialization:
+            iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
+            self.evaluate(iterator.initializer)
+        else:
+            iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
+        get_next = iterator.get_next()
+        return ta_wrapper(lambda: get_next)
+
+    def assertDatasetProduces(
+        self,
+        dataset,
+        expected_output=None,
+        expected_shapes=None,
+        expected_error=None,
+        requires_initialization=False,
+        num_test_iterations=1,
+        assert_items_equal=False,
+        expected_error_iter=1,
+    ):
+        """Asserts that a dataset produces the expected output / error.
+
+        Args:
+          dataset: A dataset to check for the expected output / error.
+          expected_output: A list of elements that the dataset is expected to
+            produce.
+          expected_shapes: A list of TensorShapes that is expected to match the
+            output_shapes of the dataset.
+          expected_error: A tuple `(type, message)` identifying the expected error
+            `dataset` should raise. The `type` should match the expected exception
+            type, while `message` should be a regular expression that is expected
+            to partially match the error message.
+          requires_initialization: Indicates that when the test is executed in graph
+            mode, it should use an initializable iterator to iterate through the
+            dataset (e.g. when it contains stateful nodes). Defaults to False.
+          num_test_iterations: Number of times `dataset` will be iterated. Defaults
+            to 1.
+          assert_items_equal: Tests that expected_output has (only) the same
+            elements regardless of order.
+          expected_error_iter: How many times to iterate before expecting an error,
+            if an error is expected.
+        """
+        self.assertTrue(
+            expected_error is not None or expected_output is not None,
+            "Exactly one of expected_output or expected error should be provided.",
+        )
+        if expected_error:
+            self.assertIsNone(
+                expected_output,
+                "Exactly one of expected_output or expected error should be provided.",
+            )
+            with self.assertRaisesRegex(expected_error[0], expected_error[1]):
+                get_next = self._getNext(
+                    dataset, requires_initialization=requires_initialization
+                )
+                for _ in range(expected_error_iter):
+                    self.evaluate(get_next())
+            return
+        if expected_shapes:
+            self.assertEqual(expected_shapes, tf.data.get_output_shapes(dataset))
+        self.assertGreater(num_test_iterations, 0)
+        for _ in range(num_test_iterations):
+            get_next = self._getNext(
+                dataset, requires_initialization=requires_initialization
+            )
+            result = []
+            for _ in range(len(expected_output)):
+                res = self.evaluate(get_next())
+                result.append(res)
+            self._assertTensorsEqual(result, expected_output, assert_items_equal)
+            with self.assertRaises(tf.errors.OutOfRangeError):
+                self.evaluate(get_next())
+            with self.assertRaises(tf.errors.OutOfRangeError):
+                self.evaluate(get_next())


 @test_util.run_all_in_graph_and_eager_modes
 class ParquetDatasetForTestingOpTest(ParquetDatasetTestBase):
-  """Tests for constructing parquet datasets and reading columns.
-
-  Input file schema:
-    message Document {
-      required int64 DocId;
-      optional group Links {
-        repeated int64 Backward;
-        repeated int64 Forward;
-      }
-      repeated group Name {
-        repeated group Language {
-          required binary Code (UTF8);
-          optional binary Country (UTF8);
+    """Tests for constructing parquet datasets and reading columns.
+ + Input file schema: + message Document { + required int64 DocId; + optional group Links { + repeated int64 Backward; + repeated int64 Forward; } - optional binary Url (UTF8); - } - } - - Input file contents: - Document - DocId: 10 - Links - Forward: 20 - Forward: 40 - Forward: 60 - Name - Language - Code: 'en-us' - Country: 'us' - Language - Code: 'en' - Url: 'http://A' - Name - Url: 'http://B' - Name - Language - Code: 'en-gb' - Country: 'gb' - Document - DocId: 20 - Links - Backward: 10 - Backward: 30 - Forward: 80 - Name - Url: 'http://C' - """ - - def testInvalidParentIndexPaths(self): - """Tests that wrong parent_index_paths order will throw an error.""" - with self.assertRaisesRegex( - tf.errors.InvalidArgumentError, - "parent_index_paths is not aligned with value_paths"): - pq_ds = parquet._RawParquetDataset( - filenames=self._test_filenames, - value_paths=["DocId", "Name.Language.Code", "Name.Language.Country"], - value_dtypes=( - tf.int64, - tf.string, - tf.string, - ), - parent_index_paths=[ - "Name.Language.Code", "Name.Language.Code", "Name.Language.Code", - "DocId", "Name.Language.Country" - ], - path_index=[0, 1, 2, 0, 2], - batch_size=1) - get_next = self._getNext(pq_ds, True) - self.evaluate(get_next()) - - def testCreate(self): - """Tests the creation of a dataset with valid parameters.""" - parquet._RawParquetDataset( - filenames=self._test_filenames, - value_paths=["DocId"], - value_dtypes=(tf.int64,), - parent_index_paths=["DocId"], - path_index=[0], - batch_size=1) - - def testFirstParentIndexOnly(self): - """Tests only reqruest one parent index. - - Even when the path contains more than one. i.e. Name.Language.Code should - have 4 parent indices (including the root). - """ - pq_ds = parquet._RawParquetDataset( - filenames=self._test_filenames, - value_paths=["Name.Language.Code"], - value_dtypes=(tf.string,), - parent_index_paths=["Name.Language.Code"], - path_index=[0], - batch_size=1) - self.assertDatasetProduces( - pq_ds, - expected_output=[(1, [0, 0, 0], [b"en-us", b"en", b"en-gb"]), - (1, [0], [])]) - - def testDatasetShuffle(self): - """Tests that the dataset supports shuffling of the dataset.""" - pq_ds = parquet._RawParquetDataset( - filenames=self._test_filenames, - value_paths=["DocId"], - value_dtypes=(tf.int64,), - parent_index_paths=["DocId"], - path_index=[0], - batch_size=1) - pq_ds = pq_ds.shuffle(10) - self.assertDatasetProduces( - pq_ds, - expected_output=[(1, [0], [10]), (1, [0], [20])], - assert_items_equal=True) - - def testBatchEvenlyDivisible_ReadsTwoMessages(self): - """Tests batch size that evenly divides the total number of messages.""" - pq_ds = parquet._RawParquetDataset( - filenames=self._test_filenames, - value_paths=["DocId"], - value_dtypes=(tf.int64,), - parent_index_paths=["DocId"], - path_index=[0], - batch_size=2) - self.assertDatasetProduces(pq_ds, expected_output=[(2, [0, 1], [10, 20])]) - - def testBatchLargerThanTotal_ReadsTwoMessages(self): - """Tests batch size that is larger than the total number of messages. - - Since the batch size is larger, the output would be less than batch size. - """ - pq_ds = parquet._RawParquetDataset( - filenames=self._test_filenames, - value_paths=["DocId"], - value_dtypes=(tf.int64,), - parent_index_paths=["DocId"], - path_index=[0], - batch_size=5) - self.assertDatasetProduces(pq_ds, expected_output=[(2, [0, 1], [10, 20])]) - - def testBatchEvenlyDivisibleContainsNones_ReadsTwoMessages(self): - """Tests batch size that evenly divides the total number of messages. 
- - And that the messages contains None values. - """ - pq_ds = parquet._RawParquetDataset( - filenames=self._test_filenames, - value_paths=["Name.Language.Country"], - value_dtypes=(tf.string,), - parent_index_paths=[ - "Name.Language.Country", "Name.Language.Country", - "Name.Language.Country" - ], - path_index=[0, 1, 2], - batch_size=2) - self.assertDatasetProduces( - pq_ds, - expected_output=[(2, [0, 0, 0, 1], [0, 0, 2], [0, 2], [b"us", b"gb"])]) - - def testBatchLargerThanTotalContainsNones_ReadsTwoMessages(self): - """Tests batch size that is larger than the total number of messages. - - And that the messages contains None values. - """ - pq_ds = parquet._RawParquetDataset( - filenames=self._test_filenames, - value_paths=["Name.Language.Country"], - value_dtypes=(tf.string,), - parent_index_paths=[ - "Name.Language.Country", "Name.Language.Country", - "Name.Language.Country" - ], - path_index=[0, 1, 2], - batch_size=5) - self.assertDatasetProduces( - pq_ds, - expected_output=[(2, [0, 0, 0, 1], [0, 0, 2], [0, 2], [b"us", b"gb"])]) - - def testMultipleColumns(self): - """Tests that the dataset supports multiple columns.""" - pq_ds = parquet._RawParquetDataset( - filenames=self._test_filenames, - value_paths=["DocId", "Name.Language.Code", "Name.Language.Country"], - value_dtypes=( - tf.int64, - tf.string, - tf.string, - ), - parent_index_paths=[ - "DocId", "Name.Language.Code", "Name.Language.Code", - "Name.Language.Code", "Name.Language.Country" - ], - path_index=[0, 0, 1, 2, 2], - batch_size=1) - self.assertDatasetProduces( - pq_ds, - expected_output=[(1, [0], [10], [0, 0, 0], [0, 0, 2], [0, 1, 2], - [b"en-us", b"en", b"en-gb"], [0, 2], [b"us", b"gb"]), - (1, [0], [20], [0], [], [], [], [], [])]) - - def testMultipleFiles(self): - """Tests that the dataset supports multiple files.""" - pq_ds = parquet._RawParquetDataset( - filenames=[self._test_filenames[0], self._rowgroup_test_filenames[0]], - value_paths=["DocId"], - value_dtypes=(tf.int64,), - parent_index_paths=["DocId"], - path_index=[0], - batch_size=1) - self.assertDatasetProduces( - pq_ds, - expected_output=[(1, [0], [10]), (1, [0], [20]), (1, [0], [10]), - (1, [0], [20]), (1, [0], [30]), (1, [0], [40])]) - - def testMultipleFilesLargeBatchSize(self): - """Tests that the dataset supports multiple files and large batch size.""" - # TODO(andylou) We don't want this behavior in the future. - # We eventually want the dataset to grab batches across files. - pq_ds = parquet._RawParquetDataset( - filenames=[self._test_filenames[0], self._rowgroup_test_filenames[0]], - value_paths=["DocId"], - value_dtypes=(tf.int64,), - parent_index_paths=["DocId"], - path_index=[0], - batch_size=5) - self.assertDatasetProduces( - pq_ds, - expected_output=[(2, [0, 1], [10, 20]), - (4, [0, 1, 2, 3], [10, 20, 30, 40])]) - - -@test_util.run_all_in_graph_and_eager_modes -class ParquetDatasetForTestingOpDataTypesTest(ParquetDatasetTestBase): - """Tests for testing that all types are handled in tensorflow. 
- - Input file schema: - message Document { - required boolean bool; - required int32 int32; - required int64 int64; - required float float; - required double double; - required binary byte_array (UTF8); - } - Input file contents: - Document - bool: false - int32: 10 - int64: 20 - float: 30.0 - double: 40.0 - byte_array: "fifty" - """ - - def testGetNext_HandlesBool(self): - """Tests that bool is translated from parquet file to tensor.""" - pq_ds = parquet._RawParquetDataset( - filenames=self._datatype_test_filenames, - value_paths=["bool"], - value_dtypes=(tf.bool,), - parent_index_paths=["bool"], - path_index=[0], - batch_size=1) - self.assertDatasetProduces(pq_ds, expected_output=[(1, [0], [False])]) - - def testGetNext_HandlesInt32(self): - """Tests that int32 is translated from parquet file to tensor.""" - pq_ds = parquet._RawParquetDataset( - filenames=self._datatype_test_filenames, - value_paths=["int32"], - value_dtypes=(tf.int32,), - parent_index_paths=["int32"], - path_index=[0], - batch_size=1) - self.assertDatasetProduces(pq_ds, expected_output=[(1, [0], [10])]) - - def testGetNext_HandlesInt64(self): - """Tests that int64 is translated from parquet file to tensor.""" - pq_ds = parquet._RawParquetDataset( - filenames=self._datatype_test_filenames, - value_paths=["int64"], - value_dtypes=(tf.int64,), - parent_index_paths=["int64"], - path_index=[0], - batch_size=1) - self.assertDatasetProduces(pq_ds, expected_output=[(1, [0], [20])]) - - def testGetNext_HandlesFloat(self): - """Tests that float is translated from parquet file to tensor.""" - pq_ds = parquet._RawParquetDataset( - filenames=self._datatype_test_filenames, - value_paths=["float"], - value_dtypes=(tf.float32,), - parent_index_paths=["float"], - path_index=[0], - batch_size=1) - self.assertDatasetProduces(pq_ds, expected_output=[(1, [0], [30.0])]) - - def testGetNext_HandlesDouble(self): - """Tests that double is translated from parquet file to tensor.""" - pq_ds = parquet._RawParquetDataset( - filenames=self._datatype_test_filenames, - value_paths=["double"], - value_dtypes=(tf.double,), - parent_index_paths=["double"], - path_index=[0], - batch_size=1) - self.assertDatasetProduces(pq_ds, expected_output=[(1, [0], [40.0])]) - - def testGetNext_HandlesString(self): - """Tests that string is translated from parquet file to tensor.""" - pq_ds = parquet._RawParquetDataset( - filenames=self._datatype_test_filenames, - value_paths=["byte_array"], - value_dtypes=(tf.string,), - parent_index_paths=["byte_array"], - path_index=[0], - batch_size=1) - self.assertDatasetProduces(pq_ds, expected_output=[(1, [0], [b"fifty"])]) - - -@test_util.run_all_in_graph_and_eager_modes -class ParquetDatasetForTestingOpRowGroupTest(ParquetDatasetTestBase): - """Tests that multiple row groups are handled. 
- - This file is the dremel_example.parquet with 2x of the same values - Input file schema: - message Document { - required int64 DocId; - optional group Links { - repeated int64 Backward; - repeated int64 Forward; - } - repeated group Name { - repeated group Language { - required binary Code (UTF8); - optional binary Country (UTF8); + repeated group Name { + repeated group Language { + required binary Code (UTF8); + optional binary Country (UTF8); + } + optional binary Url (UTF8); } - optional binary Url (UTF8); } - } - Input file contents: - RowGroup1: - Document - DocId: 10 - Links - Forward: 20 - Forward: 40 - Forward: 60 - Name - Language - Code: 'en-us' - Country: 'us' - Language - Code: 'en' - Url: 'http://A' - Name - Url: 'http://B' - Name - Language - Code: 'en-gb' - Country: 'gb' - Document - DocId: 20 - Links - Backward: 10 - Backward: 30 - Forward: 80 - Name - Url: 'http://C' - RowGroup2: - Document - DocId: 30 - Links - Forward: 200 - Forward: 400 - Forward: 600 - Name - Language - Code: 'en-us2' - Country: 'us2' - Language - Code: 'en2' - Url: 'http://A2' - Name - Url: 'http://B2' - Name - Language - Code: 'en-gb2' - Country: 'gb2' - Document - DocId: 40 - Links - Backward: 100 - Backward: 300 - Forward: 800 - Name - Url: 'http://C2' - """ - - def testTwoRowGroupsAndDefaultBatchSize(self): - """Tests default batch size with two row groups. - - Input: - Rowgroup0: + + Input file contents: Document DocId: 10 - Document - DocId: 20 - RowGroup1: - Document - DocId: 30 - Document - DocId: 40 - """ - pq_ds = parquet._RawParquetDataset( - filenames=self._rowgroup_test_filenames, - value_paths=["DocId"], - value_dtypes=(tf.int64,), - parent_index_paths=["DocId"], - path_index=[0], - batch_size=1) - self.assertDatasetProduces( - pq_ds, - expected_output=[(1, [0], [10]), (1, [0], [20]), (1, [0], [30]), - (1, [0], [40])]) - - def testTwoRowGroupsAndDefaultBatchSizeContainsNones(self): - """Tests default batch size with two row groups with None values. - - Input: - RowGroup0: - Document + Links + Forward: 20 + Forward: 40 + Forward: 60 Name Language Code: 'en-us' + Country: 'us' Language Code: 'en' + Url: 'http://A' Name + Url: 'http://B' Name Language Code: 'en-gb' + Country: 'gb' Document + DocId: 20 + Links + Backward: 10 + Backward: 30 + Forward: 80 Name - RowGroup1: - Document - Name - Language - Code: 'en-us2' - Language - Code: 'en2' - Name - Name - Language - Code: 'en-gb2' - Document - Name + Url: 'http://C' """ - pq_ds = parquet._RawParquetDataset( - filenames=self._rowgroup_test_filenames, - value_paths=["Name.Language.Code"], - value_dtypes=(tf.string,), - parent_index_paths=[ - "Name.Language.Code", "Name.Language.Code", "Name.Language.Code" - ], - path_index=[0, 1, 2], - batch_size=1) - self.assertDatasetProduces( - pq_ds, - expected_output=[ - (1, [0, 0, 0], [0, 0, 2], [0, 1, 2], [b"en-us", b"en", b"en-gb"]), - (1, [0], [], [], []), - (1, [0, 0, 0], [0, 0, 2], [0, 1, 2], [b"en-us2", b"en2", - b"en-gb2"]), - (1, [0], [], [], []) - ]) - - def testTwoRowGroupsAndEqualBatchSize(self): - """Tests batch size == row group size, with two row groups. 
- - Input: - Rowgroup0: - Document - DocId: 10 - Document - DocId: 20 - RowGroup1: - Document - DocId: 30 + + def testInvalidParentIndexPaths(self): + """Tests that wrong parent_index_paths order will throw an error.""" + with self.assertRaisesRegex( + tf.errors.InvalidArgumentError, + "parent_index_paths is not aligned with value_paths", + ): + pq_ds = parquet._RawParquetDataset( + filenames=self._test_filenames, + value_paths=["DocId", "Name.Language.Code", "Name.Language.Country"], + value_dtypes=( + tf.int64, + tf.string, + tf.string, + ), + parent_index_paths=[ + "Name.Language.Code", + "Name.Language.Code", + "Name.Language.Code", + "DocId", + "Name.Language.Country", + ], + path_index=[0, 1, 2, 0, 2], + batch_size=1, + ) + get_next = self._getNext(pq_ds, True) + self.evaluate(get_next()) + + def testCreate(self): + """Tests the creation of a dataset with valid parameters.""" + parquet._RawParquetDataset( + filenames=self._test_filenames, + value_paths=["DocId"], + value_dtypes=(tf.int64,), + parent_index_paths=["DocId"], + path_index=[0], + batch_size=1, + ) + + def testFirstParentIndexOnly(self): + """Tests only reqruest one parent index. + + Even when the path contains more than one. i.e. Name.Language.Code should + have 4 parent indices (including the root). + """ + pq_ds = parquet._RawParquetDataset( + filenames=self._test_filenames, + value_paths=["Name.Language.Code"], + value_dtypes=(tf.string,), + parent_index_paths=["Name.Language.Code"], + path_index=[0], + batch_size=1, + ) + self.assertDatasetProduces( + pq_ds, + expected_output=[(1, [0, 0, 0], [b"en-us", b"en", b"en-gb"]), (1, [0], [])], + ) + + def testDatasetShuffle(self): + """Tests that the dataset supports shuffling of the dataset.""" + pq_ds = parquet._RawParquetDataset( + filenames=self._test_filenames, + value_paths=["DocId"], + value_dtypes=(tf.int64,), + parent_index_paths=["DocId"], + path_index=[0], + batch_size=1, + ) + pq_ds = pq_ds.shuffle(10) + self.assertDatasetProduces( + pq_ds, + expected_output=[(1, [0], [10]), (1, [0], [20])], + assert_items_equal=True, + ) + + def testBatchEvenlyDivisible_ReadsTwoMessages(self): + """Tests batch size that evenly divides the total number of messages.""" + pq_ds = parquet._RawParquetDataset( + filenames=self._test_filenames, + value_paths=["DocId"], + value_dtypes=(tf.int64,), + parent_index_paths=["DocId"], + path_index=[0], + batch_size=2, + ) + self.assertDatasetProduces(pq_ds, expected_output=[(2, [0, 1], [10, 20])]) + + def testBatchLargerThanTotal_ReadsTwoMessages(self): + """Tests batch size that is larger than the total number of messages. + + Since the batch size is larger, the output would be less than batch size. + """ + pq_ds = parquet._RawParquetDataset( + filenames=self._test_filenames, + value_paths=["DocId"], + value_dtypes=(tf.int64,), + parent_index_paths=["DocId"], + path_index=[0], + batch_size=5, + ) + self.assertDatasetProduces(pq_ds, expected_output=[(2, [0, 1], [10, 20])]) + + def testBatchEvenlyDivisibleContainsNones_ReadsTwoMessages(self): + """Tests batch size that evenly divides the total number of messages. + + And that the messages contains None values. 
+ """ + pq_ds = parquet._RawParquetDataset( + filenames=self._test_filenames, + value_paths=["Name.Language.Country"], + value_dtypes=(tf.string,), + parent_index_paths=[ + "Name.Language.Country", + "Name.Language.Country", + "Name.Language.Country", + ], + path_index=[0, 1, 2], + batch_size=2, + ) + self.assertDatasetProduces( + pq_ds, + expected_output=[(2, [0, 0, 0, 1], [0, 0, 2], [0, 2], [b"us", b"gb"])], + ) + + def testBatchLargerThanTotalContainsNones_ReadsTwoMessages(self): + """Tests batch size that is larger than the total number of messages. + + And that the messages contains None values. + """ + pq_ds = parquet._RawParquetDataset( + filenames=self._test_filenames, + value_paths=["Name.Language.Country"], + value_dtypes=(tf.string,), + parent_index_paths=[ + "Name.Language.Country", + "Name.Language.Country", + "Name.Language.Country", + ], + path_index=[0, 1, 2], + batch_size=5, + ) + self.assertDatasetProduces( + pq_ds, + expected_output=[(2, [0, 0, 0, 1], [0, 0, 2], [0, 2], [b"us", b"gb"])], + ) + + def testMultipleColumns(self): + """Tests that the dataset supports multiple columns.""" + pq_ds = parquet._RawParquetDataset( + filenames=self._test_filenames, + value_paths=["DocId", "Name.Language.Code", "Name.Language.Country"], + value_dtypes=( + tf.int64, + tf.string, + tf.string, + ), + parent_index_paths=[ + "DocId", + "Name.Language.Code", + "Name.Language.Code", + "Name.Language.Code", + "Name.Language.Country", + ], + path_index=[0, 0, 1, 2, 2], + batch_size=1, + ) + self.assertDatasetProduces( + pq_ds, + expected_output=[ + ( + 1, + [0], + [10], + [0, 0, 0], + [0, 0, 2], + [0, 1, 2], + [b"en-us", b"en", b"en-gb"], + [0, 2], + [b"us", b"gb"], + ), + (1, [0], [20], [0], [], [], [], [], []), + ], + ) + + def testMultipleFiles(self): + """Tests that the dataset supports multiple files.""" + pq_ds = parquet._RawParquetDataset( + filenames=[self._test_filenames[0], self._rowgroup_test_filenames[0]], + value_paths=["DocId"], + value_dtypes=(tf.int64,), + parent_index_paths=["DocId"], + path_index=[0], + batch_size=1, + ) + self.assertDatasetProduces( + pq_ds, + expected_output=[ + (1, [0], [10]), + (1, [0], [20]), + (1, [0], [10]), + (1, [0], [20]), + (1, [0], [30]), + (1, [0], [40]), + ], + ) + + def testMultipleFilesLargeBatchSize(self): + """Tests that the dataset supports multiple files and large batch size.""" + # TODO(andylou) We don't want this behavior in the future. + # We eventually want the dataset to grab batches across files. + pq_ds = parquet._RawParquetDataset( + filenames=[self._test_filenames[0], self._rowgroup_test_filenames[0]], + value_paths=["DocId"], + value_dtypes=(tf.int64,), + parent_index_paths=["DocId"], + path_index=[0], + batch_size=5, + ) + self.assertDatasetProduces( + pq_ds, + expected_output=[ + (2, [0, 1], [10, 20]), + (4, [0, 1, 2, 3], [10, 20, 30, 40]), + ], + ) + + +@test_util.run_all_in_graph_and_eager_modes +class ParquetDatasetForTestingOpDataTypesTest(ParquetDatasetTestBase): + """Tests for testing that all types are handled in tensorflow. 
+ + Input file schema: + message Document { + required boolean bool; + required int32 int32; + required int64 int64; + required float float; + required double double; + required binary byte_array (UTF8); + } + Input file contents: Document - DocId: 40 + bool: false + int32: 10 + int64: 20 + float: 30.0 + double: 40.0 + byte_array: "fifty" """ - pq_ds = parquet._RawParquetDataset( - filenames=self._rowgroup_test_filenames, - value_paths=["DocId"], - value_dtypes=(tf.int64,), - parent_index_paths=["DocId"], - path_index=[0], - batch_size=2) - self.assertDatasetProduces( - pq_ds, expected_output=[(2, [0, 1], [10, 20]), (2, [0, 1], [30, 40])]) - - def testTwoRowGroupsAndEqualBatchSizeContainsNones(self): - """Tests batch size == row group size, with two row groups with None values. - - Input: - RowGroup0: - Document - Name - Language - Code: 'en-us' - Language - Code: 'en' - Name - Name - Language - Code: 'en-gb' - Document - Name + + def testGetNext_HandlesBool(self): + """Tests that bool is translated from parquet file to tensor.""" + pq_ds = parquet._RawParquetDataset( + filenames=self._datatype_test_filenames, + value_paths=["bool"], + value_dtypes=(tf.bool,), + parent_index_paths=["bool"], + path_index=[0], + batch_size=1, + ) + self.assertDatasetProduces(pq_ds, expected_output=[(1, [0], [False])]) + + def testGetNext_HandlesInt32(self): + """Tests that int32 is translated from parquet file to tensor.""" + pq_ds = parquet._RawParquetDataset( + filenames=self._datatype_test_filenames, + value_paths=["int32"], + value_dtypes=(tf.int32,), + parent_index_paths=["int32"], + path_index=[0], + batch_size=1, + ) + self.assertDatasetProduces(pq_ds, expected_output=[(1, [0], [10])]) + + def testGetNext_HandlesInt64(self): + """Tests that int64 is translated from parquet file to tensor.""" + pq_ds = parquet._RawParquetDataset( + filenames=self._datatype_test_filenames, + value_paths=["int64"], + value_dtypes=(tf.int64,), + parent_index_paths=["int64"], + path_index=[0], + batch_size=1, + ) + self.assertDatasetProduces(pq_ds, expected_output=[(1, [0], [20])]) + + def testGetNext_HandlesFloat(self): + """Tests that float is translated from parquet file to tensor.""" + pq_ds = parquet._RawParquetDataset( + filenames=self._datatype_test_filenames, + value_paths=["float"], + value_dtypes=(tf.float32,), + parent_index_paths=["float"], + path_index=[0], + batch_size=1, + ) + self.assertDatasetProduces(pq_ds, expected_output=[(1, [0], [30.0])]) + + def testGetNext_HandlesDouble(self): + """Tests that double is translated from parquet file to tensor.""" + pq_ds = parquet._RawParquetDataset( + filenames=self._datatype_test_filenames, + value_paths=["double"], + value_dtypes=(tf.double,), + parent_index_paths=["double"], + path_index=[0], + batch_size=1, + ) + self.assertDatasetProduces(pq_ds, expected_output=[(1, [0], [40.0])]) + + def testGetNext_HandlesString(self): + """Tests that string is translated from parquet file to tensor.""" + pq_ds = parquet._RawParquetDataset( + filenames=self._datatype_test_filenames, + value_paths=["byte_array"], + value_dtypes=(tf.string,), + parent_index_paths=["byte_array"], + path_index=[0], + batch_size=1, + ) + self.assertDatasetProduces(pq_ds, expected_output=[(1, [0], [b"fifty"])]) + + +@test_util.run_all_in_graph_and_eager_modes +class ParquetDatasetForTestingOpRowGroupTest(ParquetDatasetTestBase): + """Tests that multiple row groups are handled. 
+ + This file is the dremel_example.parquet with 2x of the same values + Input file schema: + message Document { + required int64 DocId; + optional group Links { + repeated int64 Backward; + repeated int64 Forward; + } + repeated group Name { + repeated group Language { + required binary Code (UTF8); + optional binary Country (UTF8); + } + optional binary Url (UTF8); + } + } + Input file contents: RowGroup1: - Document - Name - Language - Code: 'en-us2' - Language - Code: 'en2' - Name - Name - Language - Code: 'en-gb2' - Document - Name - """ - pq_ds = parquet._RawParquetDataset( - filenames=self._rowgroup_test_filenames, - value_paths=["Name.Language.Code"], - value_dtypes=(tf.string,), - parent_index_paths=[ - "Name.Language.Code", "Name.Language.Code", "Name.Language.Code" - ], - path_index=[0, 1, 2], - batch_size=2) - self.assertDatasetProduces( - pq_ds, - expected_output=[(2, [0, 0, 0, - 1], [0, 0, 2], [0, 1, - 2], [b"en-us", b"en", b"en-gb"]), - (2, [0, 0, 0, - 1], [0, 0, - 2], [0, 1, - 2], [b"en-us2", b"en2", b"en-gb2"])]) - - def testTwoRowGroupsAndLargerBatchSize(self): - """Tests batch size > row group size, with two row groups. - - Input: - Rowgroup0: Document DocId: 10 - Document - DocId: 20 - RowGroup1: - Document - DocId: 30 - Document - DocId: 40 - """ - pq_ds = parquet._RawParquetDataset( - filenames=self._rowgroup_test_filenames, - value_paths=["DocId"], - value_dtypes=(tf.int64,), - parent_index_paths=["DocId"], - path_index=[0], - batch_size=3) - self.assertDatasetProduces( - pq_ds, expected_output=[(3, [0, 1, 2], [10, 20, 30]), (1, [0], [40])]) - - def testTwoRowGroupsAndLargerBatchSizeContainsNones(self): - """Tests batch size > row group size, with two row groups with None values. - - Input: - RowGroup0: - Document + Links + Forward: 20 + Forward: 40 + Forward: 60 Name Language Code: 'en-us' + Country: 'us' Language Code: 'en' + Url: 'http://A' Name + Url: 'http://B' Name Language Code: 'en-gb' + Country: 'gb' Document + DocId: 20 + Links + Backward: 10 + Backward: 30 + Forward: 80 Name - RowGroup1: + Url: 'http://C' + RowGroup2: Document + DocId: 30 + Links + Forward: 200 + Forward: 400 + Forward: 600 Name Language Code: 'en-us2' + Country: 'us2' Language Code: 'en2' + Url: 'http://A2' Name + Url: 'http://B2' Name Language Code: 'en-gb2' + Country: 'gb2' Document - Name - """ - pq_ds = parquet._RawParquetDataset( - filenames=self._rowgroup_test_filenames, - value_paths=["Name.Language.Code"], - value_dtypes=(tf.string,), - parent_index_paths=[ - "Name.Language.Code", "Name.Language.Code", "Name.Language.Code" - ], - path_index=[0, 1, 2], - batch_size=3) - self.assertDatasetProduces( - pq_ds, - expected_output=[ - (3, [0, 0, 0, 1, 2, 2, 2], [0, 0, 2, 4, 4, 6], [0, 1, 2, 3, 4, 5], - [b"en-us", b"en", b"en-gb", b"en-us2", b"en2", b"en-gb2"]), - (1, [0], [], [], []) - ]) - - def testTwoRowGroupsAndDefaultBatchSizeFirstValueIsNone(self): - """Tests batch size < row group size, where the first value is None. - - This tests that the buffer is still used properly, even if ther value - cached is None. 
- Input: - Rowgroup0: - Document - Links - Document - Links - Backward: 10 - Backward: 30 - RowGroup1: - Document - Links - Document - Links - Backward: 100 - Backward: 300 - """ - pq_ds = parquet._RawParquetDataset( - filenames=self._rowgroup_test_filenames, - value_paths=["Links.Backward"], - value_dtypes=(tf.int64,), - parent_index_paths=["Links.Backward", "Links.Backward"], - path_index=[0, 1], - batch_size=1) - self.assertDatasetProduces( - pq_ds, - expected_output=[(1, [0], [], []), (1, [0], [0, 0], [10, 30]), - (1, [0], [], []), (1, [0], [0, 0], [100, 300])]) - - def testTwoRowGroupsAndEqualBatchSizeFirstValueIsNone(self): - """Tests batch size == row group size, where the first value is None. - - This tests that the buffer is still used properly, even if the value - cached is None. - Input: - Rowgroup0: - Document - Links - Document - Links - Backward: 10 - Backward: 30 - RowGroup1: - Document - Links - Document - Links - Backward: 100 - Backward: 300 - """ - pq_ds = parquet._RawParquetDataset( - filenames=self._rowgroup_test_filenames, - value_paths=["Links.Backward"], - value_dtypes=(tf.int64,), - parent_index_paths=["Links.Backward", "Links.Backward"], - path_index=[0, 1], - batch_size=2) - self.assertDatasetProduces( - pq_ds, - expected_output=[(2, [0, 1], [1, 1], [10, 30]), - (2, [0, 1], [1, 1], [100, 300])]) - - def testTwoRowGroupsAndLargerBatchSizeFirstValueIsNone(self): - """Tests batch size > row group size, where the first value is None. - - This tests that the buffer is still used properly, even if the value - cached is None. - Input: - Rowgroup0: - Document - Links - Document - Links - Backward: 10 - Backward: 30 - RowGroup1: - Document - Links - Document + DocId: 40 Links Backward: 100 Backward: 300 - """ - pq_ds = parquet._RawParquetDataset( - filenames=self._rowgroup_test_filenames, - value_paths=["Links.Backward"], - value_dtypes=(tf.int64,), - parent_index_paths=["Links.Backward", "Links.Backward"], - path_index=[0, 1], - batch_size=3) - self.assertDatasetProduces( - pq_ds, - expected_output=[(3, [0, 1, 2], [1, 1], [10, 30]), - (1, [0], [0, 0], [100, 300])]) - - def testTwoRowGroupsAndDefaultBatchSizeLargeFirstMessage(self): - """Tests batch size < row group size, where the first message is large. - - Input: - Rowgroup0: - Document - Links - Forward: 20 - Forward: 40 - Forward: 60 - Document - Links - Forward: 80 - RowGroup1: - Document - Links - Forward: 200 - Forward: 400 - Forward: 600 - Document - Links - Forward: 800 - """ - pq_ds = parquet._RawParquetDataset( - filenames=self._rowgroup_test_filenames, - value_paths=["Links.Forward"], - value_dtypes=(tf.int64,), - parent_index_paths=["Links.Forward", "Links.Forward"], - path_index=[0, 1], - batch_size=1) - self.assertDatasetProduces( - pq_ds, - expected_output=[(1, [0], [0, 0, 0], [20, 40, 60]), (1, [0], [0], [80]), - (1, [0], [0, 0, 0], [200, 400, 600]), - (1, [0], [0], [800])]) - - def testTwoRowGroupsAndEqualBatchSizeLargeFirstMessage(self): - """Tests batch size == row group size, where the first message is large. 
- - Input: - Rowgroup0: - Document - Links - Forward: 20 - Forward: 40 - Forward: 60 - Document - Links - Forward: 80 - RowGroup1: - Document - Links - Forward: 200 - Forward: 400 - Forward: 600 - Document - Links - Forward: 800 - """ - pq_ds = parquet._RawParquetDataset( - filenames=self._rowgroup_test_filenames, - value_paths=["Links.Forward"], - value_dtypes=(tf.int64,), - parent_index_paths=["Links.Forward", "Links.Forward"], - path_index=[0, 1], - batch_size=2) - self.assertDatasetProduces( - pq_ds, - expected_output=[(2, [0, 1], [0, 0, 0, 1], [20, 40, 60, 80]), - (2, [0, 1], [0, 0, 0, 1], [200, 400, 600, 800])]) - - def testTwoRowGroupsAndLargerBatchSizeLargeFirstMessage(self): - """Tests batch size > row group size, where the first message is large. - - Input: - Rowgroup0: - Document - Links - Forward: 20 - Forward: 40 - Forward: 60 - Document - Links - Forward: 80 - RowGroup1: - Document - Links - Forward: 200 - Forward: 400 - Forward: 600 - Document - Links Forward: 800 + Name + Url: 'http://C2' """ - pq_ds = parquet._RawParquetDataset( - filenames=self._rowgroup_test_filenames, - value_paths=["Links.Forward"], - value_dtypes=(tf.int64,), - parent_index_paths=["Links.Forward", "Links.Forward"], - path_index=[0, 1], - batch_size=3) - self.assertDatasetProduces( - pq_ds, - expected_output=[(3, [0, 1, 2], [0, 0, 0, 1, 2, 2, - 2], [20, 40, 60, 80, 200, 400, 600]), - (1, [0], [0], [800])]) - - def testMultipleColumns_TwoRowGroupsAndEqualBatchSize(self): - """Tests that the dataset supports multiple columns.""" - pq_ds = parquet._RawParquetDataset( - filenames=self._rowgroup_test_filenames, - value_paths=["DocId", "Name.Language.Code", "Name.Language.Country"], - value_dtypes=( - tf.int64, - tf.string, - tf.string, - ), - parent_index_paths=[ - "DocId", "Name.Language.Code", "Name.Language.Code", - "Name.Language.Code", "Name.Language.Country" - ], - path_index=[0, 0, 1, 2, 2], - batch_size=2) - self.assertDatasetProduces( - pq_ds, - expected_output=[(2, [0, 1], [10, 20], [0, 0, 0, 1], [0, 0, 2], - [0, 1, 2], [b"en-us", b"en", - b"en-gb"], [0, 2], [b"us", b"gb"]), - (2, [0, 1], [30, 40], [0, 0, 0, 1], [0, 0, 2], - [0, 1, 2], [b"en-us2", b"en2", - b"en-gb2"], [0, 2], [b"us2", b"gb2"])]) + + def testTwoRowGroupsAndDefaultBatchSize(self): + """Tests default batch size with two row groups. + + Input: + Rowgroup0: + Document + DocId: 10 + Document + DocId: 20 + RowGroup1: + Document + DocId: 30 + Document + DocId: 40 + """ + pq_ds = parquet._RawParquetDataset( + filenames=self._rowgroup_test_filenames, + value_paths=["DocId"], + value_dtypes=(tf.int64,), + parent_index_paths=["DocId"], + path_index=[0], + batch_size=1, + ) + self.assertDatasetProduces( + pq_ds, + expected_output=[ + (1, [0], [10]), + (1, [0], [20]), + (1, [0], [30]), + (1, [0], [40]), + ], + ) + + def testTwoRowGroupsAndDefaultBatchSizeContainsNones(self): + """Tests default batch size with two row groups with None values. 
+ + Input: + RowGroup0: + Document + Name + Language + Code: 'en-us' + Language + Code: 'en' + Name + Name + Language + Code: 'en-gb' + Document + Name + RowGroup1: + Document + Name + Language + Code: 'en-us2' + Language + Code: 'en2' + Name + Name + Language + Code: 'en-gb2' + Document + Name + """ + pq_ds = parquet._RawParquetDataset( + filenames=self._rowgroup_test_filenames, + value_paths=["Name.Language.Code"], + value_dtypes=(tf.string,), + parent_index_paths=[ + "Name.Language.Code", + "Name.Language.Code", + "Name.Language.Code", + ], + path_index=[0, 1, 2], + batch_size=1, + ) + self.assertDatasetProduces( + pq_ds, + expected_output=[ + (1, [0, 0, 0], [0, 0, 2], [0, 1, 2], [b"en-us", b"en", b"en-gb"]), + (1, [0], [], [], []), + (1, [0, 0, 0], [0, 0, 2], [0, 1, 2], [b"en-us2", b"en2", b"en-gb2"]), + (1, [0], [], [], []), + ], + ) + + def testTwoRowGroupsAndEqualBatchSize(self): + """Tests batch size == row group size, with two row groups. + + Input: + Rowgroup0: + Document + DocId: 10 + Document + DocId: 20 + RowGroup1: + Document + DocId: 30 + Document + DocId: 40 + """ + pq_ds = parquet._RawParquetDataset( + filenames=self._rowgroup_test_filenames, + value_paths=["DocId"], + value_dtypes=(tf.int64,), + parent_index_paths=["DocId"], + path_index=[0], + batch_size=2, + ) + self.assertDatasetProduces( + pq_ds, expected_output=[(2, [0, 1], [10, 20]), (2, [0, 1], [30, 40])] + ) + + def testTwoRowGroupsAndEqualBatchSizeContainsNones(self): + """Tests batch size == row group size, with two row groups with None values. + + Input: + RowGroup0: + Document + Name + Language + Code: 'en-us' + Language + Code: 'en' + Name + Name + Language + Code: 'en-gb' + Document + Name + RowGroup1: + Document + Name + Language + Code: 'en-us2' + Language + Code: 'en2' + Name + Name + Language + Code: 'en-gb2' + Document + Name + """ + pq_ds = parquet._RawParquetDataset( + filenames=self._rowgroup_test_filenames, + value_paths=["Name.Language.Code"], + value_dtypes=(tf.string,), + parent_index_paths=[ + "Name.Language.Code", + "Name.Language.Code", + "Name.Language.Code", + ], + path_index=[0, 1, 2], + batch_size=2, + ) + self.assertDatasetProduces( + pq_ds, + expected_output=[ + (2, [0, 0, 0, 1], [0, 0, 2], [0, 1, 2], [b"en-us", b"en", b"en-gb"]), + (2, [0, 0, 0, 1], [0, 0, 2], [0, 1, 2], [b"en-us2", b"en2", b"en-gb2"]), + ], + ) + + def testTwoRowGroupsAndLargerBatchSize(self): + """Tests batch size > row group size, with two row groups. + + Input: + Rowgroup0: + Document + DocId: 10 + Document + DocId: 20 + RowGroup1: + Document + DocId: 30 + Document + DocId: 40 + """ + pq_ds = parquet._RawParquetDataset( + filenames=self._rowgroup_test_filenames, + value_paths=["DocId"], + value_dtypes=(tf.int64,), + parent_index_paths=["DocId"], + path_index=[0], + batch_size=3, + ) + self.assertDatasetProduces( + pq_ds, expected_output=[(3, [0, 1, 2], [10, 20, 30]), (1, [0], [40])] + ) + + def testTwoRowGroupsAndLargerBatchSizeContainsNones(self): + """Tests batch size > row group size, with two row groups with None values. 
+
+        Input:
+          RowGroup0:
+            Document
+              Name
+                Language
+                  Code: 'en-us'
+                Language
+                  Code: 'en'
+              Name
+              Name
+                Language
+                  Code: 'en-gb'
+            Document
+              Name
+          RowGroup1:
+            Document
+              Name
+                Language
+                  Code: 'en-us2'
+                Language
+                  Code: 'en2'
+              Name
+              Name
+                Language
+                  Code: 'en-gb2'
+            Document
+              Name
+        """
+        pq_ds = parquet._RawParquetDataset(
+            filenames=self._rowgroup_test_filenames,
+            value_paths=["Name.Language.Code"],
+            value_dtypes=(tf.string,),
+            parent_index_paths=[
+                "Name.Language.Code",
+                "Name.Language.Code",
+                "Name.Language.Code",
+            ],
+            path_index=[0, 1, 2],
+            batch_size=3,
+        )
+        self.assertDatasetProduces(
+            pq_ds,
+            expected_output=[
+                (
+                    3,
+                    [0, 0, 0, 1, 2, 2, 2],
+                    [0, 0, 2, 4, 4, 6],
+                    [0, 1, 2, 3, 4, 5],
+                    [b"en-us", b"en", b"en-gb", b"en-us2", b"en2", b"en-gb2"],
+                ),
+                (1, [0], [], [], []),
+            ],
+        )
+
+    def testTwoRowGroupsAndDefaultBatchSizeFirstValueIsNone(self):
+        """Tests batch size < row group size, where the first value is None.
+
+        This tests that the buffer is still used properly, even if the cached
+        value is None.
+        Input:
+          Rowgroup0:
+            Document
+              Links
+            Document
+              Links
+                Backward: 10
+                Backward: 30
+          RowGroup1:
+            Document
+              Links
+            Document
+              Links
+                Backward: 100
+                Backward: 300
+        """
+        pq_ds = parquet._RawParquetDataset(
+            filenames=self._rowgroup_test_filenames,
+            value_paths=["Links.Backward"],
+            value_dtypes=(tf.int64,),
+            parent_index_paths=["Links.Backward", "Links.Backward"],
+            path_index=[0, 1],
+            batch_size=1,
+        )
+        self.assertDatasetProduces(
+            pq_ds,
+            expected_output=[
+                (1, [0], [], []),
+                (1, [0], [0, 0], [10, 30]),
+                (1, [0], [], []),
+                (1, [0], [0, 0], [100, 300]),
+            ],
+        )
+
+    def testTwoRowGroupsAndEqualBatchSizeFirstValueIsNone(self):
+        """Tests batch size == row group size, where the first value is None.
+
+        This tests that the buffer is still used properly, even if the cached
+        value is None.
+        Input:
+          Rowgroup0:
+            Document
+              Links
+            Document
+              Links
+                Backward: 10
+                Backward: 30
+          RowGroup1:
+            Document
+              Links
+            Document
+              Links
+                Backward: 100
+                Backward: 300
+        """
+        pq_ds = parquet._RawParquetDataset(
+            filenames=self._rowgroup_test_filenames,
+            value_paths=["Links.Backward"],
+            value_dtypes=(tf.int64,),
+            parent_index_paths=["Links.Backward", "Links.Backward"],
+            path_index=[0, 1],
+            batch_size=2,
+        )
+        self.assertDatasetProduces(
+            pq_ds,
+            expected_output=[
+                (2, [0, 1], [1, 1], [10, 30]),
+                (2, [0, 1], [1, 1], [100, 300]),
+            ],
+        )
+
+    def testTwoRowGroupsAndLargerBatchSizeFirstValueIsNone(self):
+        """Tests batch size > row group size, where the first value is None.
+
+        This tests that the buffer is still used properly, even if the cached
+        value is None.
+        Input:
+          Rowgroup0:
+            Document
+              Links
+            Document
+              Links
+                Backward: 10
+                Backward: 30
+          RowGroup1:
+            Document
+              Links
+            Document
+              Links
+                Backward: 100
+                Backward: 300
+        """
+        pq_ds = parquet._RawParquetDataset(
+            filenames=self._rowgroup_test_filenames,
+            value_paths=["Links.Backward"],
+            value_dtypes=(tf.int64,),
+            parent_index_paths=["Links.Backward", "Links.Backward"],
+            path_index=[0, 1],
+            batch_size=3,
+        )
+        self.assertDatasetProduces(
+            pq_ds,
+            expected_output=[
+                (3, [0, 1, 2], [1, 1], [10, 30]),
+                (1, [0], [0, 0], [100, 300]),
+            ],
+        )
+
+    def testTwoRowGroupsAndDefaultBatchSizeLargeFirstMessage(self):
+        """Tests batch size < row group size, where the first message is large.
+ + Input: + Rowgroup0: + Document + Links + Forward: 20 + Forward: 40 + Forward: 60 + Document + Links + Forward: 80 + RowGroup1: + Document + Links + Forward: 200 + Forward: 400 + Forward: 600 + Document + Links + Forward: 800 + """ + pq_ds = parquet._RawParquetDataset( + filenames=self._rowgroup_test_filenames, + value_paths=["Links.Forward"], + value_dtypes=(tf.int64,), + parent_index_paths=["Links.Forward", "Links.Forward"], + path_index=[0, 1], + batch_size=1, + ) + self.assertDatasetProduces( + pq_ds, + expected_output=[ + (1, [0], [0, 0, 0], [20, 40, 60]), + (1, [0], [0], [80]), + (1, [0], [0, 0, 0], [200, 400, 600]), + (1, [0], [0], [800]), + ], + ) + + def testTwoRowGroupsAndEqualBatchSizeLargeFirstMessage(self): + """Tests batch size == row group size, where the first message is large. + + Input: + Rowgroup0: + Document + Links + Forward: 20 + Forward: 40 + Forward: 60 + Document + Links + Forward: 80 + RowGroup1: + Document + Links + Forward: 200 + Forward: 400 + Forward: 600 + Document + Links + Forward: 800 + """ + pq_ds = parquet._RawParquetDataset( + filenames=self._rowgroup_test_filenames, + value_paths=["Links.Forward"], + value_dtypes=(tf.int64,), + parent_index_paths=["Links.Forward", "Links.Forward"], + path_index=[0, 1], + batch_size=2, + ) + self.assertDatasetProduces( + pq_ds, + expected_output=[ + (2, [0, 1], [0, 0, 0, 1], [20, 40, 60, 80]), + (2, [0, 1], [0, 0, 0, 1], [200, 400, 600, 800]), + ], + ) + + def testTwoRowGroupsAndLargerBatchSizeLargeFirstMessage(self): + """Tests batch size > row group size, where the first message is large. + + Input: + Rowgroup0: + Document + Links + Forward: 20 + Forward: 40 + Forward: 60 + Document + Links + Forward: 80 + RowGroup1: + Document + Links + Forward: 200 + Forward: 400 + Forward: 600 + Document + Links + Forward: 800 + """ + pq_ds = parquet._RawParquetDataset( + filenames=self._rowgroup_test_filenames, + value_paths=["Links.Forward"], + value_dtypes=(tf.int64,), + parent_index_paths=["Links.Forward", "Links.Forward"], + path_index=[0, 1], + batch_size=3, + ) + self.assertDatasetProduces( + pq_ds, + expected_output=[ + (3, [0, 1, 2], [0, 0, 0, 1, 2, 2, 2], [20, 40, 60, 80, 200, 400, 600]), + (1, [0], [0], [800]), + ], + ) + + def testMultipleColumns_TwoRowGroupsAndEqualBatchSize(self): + """Tests that the dataset supports multiple columns.""" + pq_ds = parquet._RawParquetDataset( + filenames=self._rowgroup_test_filenames, + value_paths=["DocId", "Name.Language.Code", "Name.Language.Country"], + value_dtypes=( + tf.int64, + tf.string, + tf.string, + ), + parent_index_paths=[ + "DocId", + "Name.Language.Code", + "Name.Language.Code", + "Name.Language.Code", + "Name.Language.Country", + ], + path_index=[0, 0, 1, 2, 2], + batch_size=2, + ) + self.assertDatasetProduces( + pq_ds, + expected_output=[ + ( + 2, + [0, 1], + [10, 20], + [0, 0, 0, 1], + [0, 0, 2], + [0, 1, 2], + [b"en-us", b"en", b"en-gb"], + [0, 2], + [b"us", b"gb"], + ), + ( + 2, + [0, 1], + [30, 40], + [0, 0, 0, 1], + [0, 0, 2], + [0, 1, 2], + [b"en-us2", b"en2", b"en-gb2"], + [0, 2], + [b"us2", b"gb2"], + ), + ], + ) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_impl/reroot.py b/struct2tensor/expression_impl/reroot.py index 0d2e894..eaa9e9c 100644 --- a/struct2tensor/expression_impl/reroot.py +++ b/struct2tensor/expression_impl/reroot.py @@ -18,6 +18,7 @@ original proto. 
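As exercised by reroot_test.py later in this patch, the module's two entry points compose like this (a sketch; `create_expression` and `prensor_test_util` are the struct2tensor test helpers used in that test):

```python
expr = create_expression.create_expression_from_prensor(
    prensor_test_util.create_big_prensor()
)
# Reroot the tree at "doc", then attach a field mapping each new root
# element back to the index of the proto it came from.
new_root = reroot.reroot(expr, path.Path(["doc"]))
indexed = reroot.create_proto_index_field(new_root, "proto_index")
```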
""" + from typing import FrozenSet, Optional, Sequence import tensorflow as tf @@ -25,148 +26,156 @@ from struct2tensor import calculate_options, expression, expression_add, path, prensor -def reroot(root: expression.Expression, - source_path: path.Path) -> expression.Expression: - """Reroot to a new path, maintaining a input proto index. +def reroot( + root: expression.Expression, source_path: path.Path +) -> expression.Expression: + """Reroot to a new path, maintaining a input proto index. - Similar to root.get_descendant_or_error(source_path): however, this - method retains the ability to get a map to the original index. + Similar to root.get_descendant_or_error(source_path): however, this + method retains the ability to get a map to the original index. - Args: - root: the original root. - source_path: the path to the new root. + Args: + root: the original root. + source_path: the path to the new root. - Returns: - the new root. - """ - new_root = root - for step in source_path.field_list: - new_root = _RerootExpression(new_root, step) - return new_root + Returns: + the new root. + """ + new_root = root + for step in source_path.field_list: + new_root = _RerootExpression(new_root, step) + return new_root -def create_proto_index_field(root: expression.Expression, - new_field_name: path.Step - ) -> expression.Expression: - return expression_add.add_paths( - root, {path.Path([new_field_name]): _InputProtoIndexExpression(root)}) +def create_proto_index_field( + root: expression.Expression, new_field_name: path.Step +) -> expression.Expression: + return expression_add.add_paths( + root, {path.Path([new_field_name]): _InputProtoIndexExpression(root)} + ) class _RerootRootNodeTensor(prensor.RootNodeTensor): - """The reroot root node. + """The reroot root node. - This contains a map from a current index to the original index of a proto. - """ + This contains a map from a current index to the original index of a proto. 
+ """ - def __init__(self, size: tf.Tensor, input_proto_index: tf.Tensor): - super().__init__(size) - self._input_proto_index = input_proto_index + def __init__(self, size: tf.Tensor, input_proto_index: tf.Tensor): + super().__init__(size) + self._input_proto_index = input_proto_index - @property - def input_proto_index(self): - return self._input_proto_index + @property + def input_proto_index(self): + return self._input_proto_index def _get_proto_index_parent_index(node: prensor.RootNodeTensor): - return tf.range(node.size) + return tf.range(node.size) def _get_input_proto_index(node: prensor.RootNodeTensor): - if isinstance(node, _RerootRootNodeTensor): - return node.input_proto_index - return _get_proto_index_parent_index(node) + if isinstance(node, _RerootRootNodeTensor): + return node.input_proto_index + return _get_proto_index_parent_index(node) class _RerootExpression(expression.Expression): - """Reroot to a new path, maintaining a input proto index.""" - - def __init__(self, original_root: expression.Expression, - field_name: path.Step): - super().__init__( - True, None, validate_step_format=original_root.validate_step_format - ) - self._field_name = field_name - self._original_root = original_root - self._new_root = original_root.get_child_or_error(field_name) - if self._new_root.type is not None: - raise ValueError(f"New root must be a message type: {str(self._field_name)}") - # TODO(martinz): Check that the "original root source expression" has a type - # in (_RerootExpression, prensor._ProtoRootExpression) - # To do this, we need a general technique similar to - # expression_add._is_true_source_expression: however, this should also cover - # intermediate operations like "project". - # Since this check is not present, if it should have fired, there will be - # an error when calculate(...) is called. - - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [self._original_root, self._new_root] - - def calculate( - self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: - [old_root_value, new_root_value] = sources - if isinstance(old_root_value, prensor.RootNodeTensor) and isinstance( - new_root_value, prensor.ChildNodeTensor): - old_input_proto_index = _get_input_proto_index(old_root_value) - # Notice that the "gather" operation is similar to promote. 
-      return _RerootRootNodeTensor(
-          tf.size(new_root_value.parent_index, out_type=tf.int64),
-          tf.gather(old_input_proto_index, new_root_value.parent_index))
-    raise ValueError("Source types incorrect")
-
-  def calculation_is_identity(self) -> bool:
-    return False
-
-  def calculation_equal(self, expr: expression.Expression) -> bool:
-    # Although path can vary, it is not used in the calculation, just to
-    return isinstance(expr, _RerootExpression)
-
-  def _get_child_impl(self,
-                      field_name: path.Step) -> Optional[expression.Expression]:
-    return self._new_root.get_child(field_name)
-
-  def known_field_names(self) -> FrozenSet[path.Step]:
-    return self._new_root.known_field_names()
+    """Reroot to a new path, maintaining an input proto index."""
+
+    def __init__(self, original_root: expression.Expression, field_name: path.Step):
+        super().__init__(
+            True, None, validate_step_format=original_root.validate_step_format
+        )
+        self._field_name = field_name
+        self._original_root = original_root
+        self._new_root = original_root.get_child_or_error(field_name)
+        if self._new_root.type is not None:
+            raise ValueError(
+                f"New root must be a message type: {str(self._field_name)}"
+            )
+        # TODO(martinz): Check that the "original root source expression" has a type
+        # in (_RerootExpression, prensor._ProtoRootExpression)
+        # To do this, we need a general technique similar to
+        # expression_add._is_true_source_expression: however, this should also cover
+        # intermediate operations like "project".
+        # Since this check is not present, if it should have fired, there will be
+        # an error when calculate(...) is called.

+    def get_source_expressions(self) -> Sequence[expression.Expression]:
+        return [self._original_root, self._new_root]
+
+    def calculate(
+        self,
+        sources: Sequence[prensor.NodeTensor],
+        destinations: Sequence[expression.Expression],
+        options: calculate_options.Options,
+        side_info: Optional[prensor.Prensor] = None,
+    ) -> prensor.NodeTensor:
+        [old_root_value, new_root_value] = sources
+        if isinstance(old_root_value, prensor.RootNodeTensor) and isinstance(
+            new_root_value, prensor.ChildNodeTensor
+        ):
+            old_input_proto_index = _get_input_proto_index(old_root_value)
+            # Notice that the "gather" operation is similar to promote.
+            return _RerootRootNodeTensor(
+                tf.size(new_root_value.parent_index, out_type=tf.int64),
+                tf.gather(old_input_proto_index, new_root_value.parent_index),
+            )
+        raise ValueError("Source types incorrect")
+
+    def calculation_is_identity(self) -> bool:
+        return False
+
+    def calculation_equal(self, expr: expression.Expression) -> bool:
+        # Although the path can vary, it is not used in the calculation, so
+        # type equality is sufficient here.
+        return isinstance(expr, _RerootExpression)
+
+    def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]:
+        return self._new_root.get_child(field_name)
+
+    def known_field_names(self) -> FrozenSet[path.Step]:
+        return self._new_root.known_field_names()


 class _InputProtoIndexExpression(expression.Leaf):
-  """A proto index expression."""
-
-  def __init__(self, root: expression.Expression):
-    """Constructor for proto index expression.
-
-    Args:
-      root: an expression that must return a RootNodeTensor.
- """ - super().__init__(is_repeated=False, my_type=tf.int64) - self._root = root - - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [self._root] - - def calculate( - self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: - [root_node] = sources - # The following check ensures not just that we can calculate the value, - # but that no "improper" reroots were done. - if isinstance(root_node, prensor.RootNodeTensor): - return prensor.LeafNodeTensor( - _get_proto_index_parent_index(root_node), - _get_input_proto_index(root_node), - is_repeated=False) - raise ValueError( - f"Illegal operation: expected a true root node: got {str(root_node)}") - - def calculation_is_identity(self) -> bool: - return False - - def calculation_equal(self, expr: expression.Expression) -> bool: - # Although path can vary, it is not used in the calculation, just to - return isinstance(expr, _InputProtoIndexExpression) + """A proto index expression.""" + + def __init__(self, root: expression.Expression): + """Constructor for proto index expression. + + Args: + root: an expression that must return a RootNodeTensor. + """ + super().__init__(is_repeated=False, my_type=tf.int64) + self._root = root + + def get_source_expressions(self) -> Sequence[expression.Expression]: + return [self._root] + + def calculate( + self, + sources: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> prensor.NodeTensor: + [root_node] = sources + # The following check ensures not just that we can calculate the value, + # but that no "improper" reroots were done. 
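+        # For intuition (values mirroring reroot_test below): with three input
+        # protos and no reroot, the proto index is just the identity [0, 1, 2]
+        # from _get_proto_index_parent_index; after rerooting the big prensor
+        # at "doc", the three new roots originate from protos 0, 1, and 1, so
+        # the leaf is parent_index=[0, 1, 2] with values=[0, 1, 1].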
+        if isinstance(root_node, prensor.RootNodeTensor):
+            return prensor.LeafNodeTensor(
+                _get_proto_index_parent_index(root_node),
+                _get_input_proto_index(root_node),
+                is_repeated=False,
+            )
+        raise ValueError(
+            f"Illegal operation: expected a true root node: got {str(root_node)}"
+        )
+
+    def calculation_is_identity(self) -> bool:
+        return False
+
+    def calculation_equal(self, expr: expression.Expression) -> bool:
+        # Although the path can vary, it is not used in the calculation, so any
+        # two _InputProtoIndexExpression instances compute the same function.
+        return isinstance(expr, _InputProtoIndexExpression)
diff --git a/struct2tensor/expression_impl/reroot_test.py b/struct2tensor/expression_impl/reroot_test.py
index ad3a8ff..4d8bad7 100644
--- a/struct2tensor/expression_impl/reroot_test.py
+++ b/struct2tensor/expression_impl/reroot_test.py
@@ -16,7 +16,7 @@
 import tensorflow as tf
 from absl.testing import absltest
 from tensorflow.python.framework import (
-    test_util,  # pylint: disable=g-direct-tensorflow-import
+    test_util,  # pylint: disable=g-direct-tensorflow-import
 )
 
 from struct2tensor import calculate, create_expression, path
@@ -26,146 +26,155 @@
 
 @test_util.run_all_in_graph_and_eager_modes
 class RerootTest(tf.test.TestCase):
-
-  def test_reroot_and_create_proto_index(self):
-    expr = create_expression.create_expression_from_prensor(
-        prensor_test_util.create_big_prensor())
-    new_root = reroot.reroot(expr, path.Path(["doc"]))
-    proto_index = reroot.create_proto_index_field(
-        new_root, "proto_index").get_child("proto_index")
-    new_field = new_root.get_child("bar")
-    leaf_node = expression_test_util.calculate_value_slowly(new_field)
-    proto_index_node = expression_test_util.calculate_value_slowly(proto_index)
-
-    self.assertIsNotNone(new_field)
-    self.assertTrue(new_field.is_repeated)
-    self.assertEqual(new_field.type, tf.string)
-    self.assertTrue(new_field.is_leaf)
-    self.assertEqual(new_field.known_field_names(), frozenset())
-    self.assertEqual(leaf_node.values.dtype, tf.string)
-
-    self.assertIsNotNone(proto_index)
-    self.assertFalse(proto_index.is_repeated)
-    self.assertEqual(proto_index.type, tf.int64)
-    self.assertTrue(proto_index.is_leaf)
-    self.assertEqual(proto_index.known_field_names(), frozenset())
-
-    self.assertEqual(proto_index_node.values.dtype, tf.int64)
-
-    self.assertAllEqual([b"a", b"b", b"c", b"d"], leaf_node.values)
-    self.assertAllEqual([0, 1, 1, 2], leaf_node.parent_index)
-    self.assertAllEqual([0, 1, 1], proto_index_node.values)
-    self.assertAllEqual([0, 1, 2], proto_index_node.parent_index)
-
-  def test_reroot_and_create_proto_index_deep(self):
-    expr = create_expression.create_expression_from_prensor(
-        prensor_test_util.create_deep_prensor())
-    new_root = reroot.reroot(expr, path.Path(["event", "doc"]))
-    proto_index = reroot.create_proto_index_field(
-        new_root, "proto_index").get_child("proto_index")
-    new_field = new_root.get_child("bar")
-    leaf_node = expression_test_util.calculate_value_slowly(new_field)
-    proto_index_node = expression_test_util.calculate_value_slowly(proto_index)
-
-    self.assertIsNotNone(new_field)
-    self.assertTrue(new_field.is_repeated)
-    self.assertEqual(new_field.type, tf.string)
-    self.assertTrue(new_field.is_leaf)
-    self.assertEqual(new_field.known_field_names(), frozenset())
-    self.assertEqual(leaf_node.values.dtype, tf.string)
-
-    self.assertIsNotNone(proto_index)
-    self.assertFalse(proto_index.is_repeated)
-    self.assertEqual(proto_index.type, tf.int64)
-    self.assertTrue(proto_index.is_leaf)
-    self.assertEqual(proto_index.known_field_names(), frozenset())
-
-
self.assertEqual(proto_index_node.values.dtype, tf.int64) - - self.assertAllEqual([b"a", b"b", b"c", b"d"], leaf_node.values) - self.assertAllEqual([0, 1, 1, 2], leaf_node.parent_index) - self.assertAllEqual([0, 1, 1], proto_index_node.values) - self.assertAllEqual([0, 1, 2], proto_index_node.parent_index) - - def test_create_proto_index_directly_reroot_at_action(self): - sessions = [ - """ + def test_reroot_and_create_proto_index(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + new_root = reroot.reroot(expr, path.Path(["doc"])) + proto_index = reroot.create_proto_index_field( + new_root, "proto_index" + ).get_child("proto_index") + new_field = new_root.get_child("bar") + leaf_node = expression_test_util.calculate_value_slowly(new_field) + proto_index_node = expression_test_util.calculate_value_slowly(proto_index) + + self.assertIsNotNone(new_field) + self.assertTrue(new_field.is_repeated) + self.assertEqual(new_field.type, tf.string) + self.assertTrue(new_field.is_leaf) + self.assertEqual(new_field.known_field_names(), frozenset()) + self.assertEqual(leaf_node.values.dtype, tf.string) + + self.assertIsNotNone(proto_index) + self.assertFalse(proto_index.is_repeated) + self.assertEqual(proto_index.type, tf.int64) + self.assertTrue(proto_index.is_leaf) + self.assertEqual(proto_index.known_field_names(), frozenset()) + + self.assertEqual(proto_index_node.values.dtype, tf.int64) + + self.assertAllEqual([b"a", b"b", b"c", b"d"], leaf_node.values) + self.assertAllEqual([0, 1, 1, 2], leaf_node.parent_index) + self.assertAllEqual([0, 1, 1], proto_index_node.values) + self.assertAllEqual([0, 1, 2], proto_index_node.parent_index) + + def test_reroot_and_create_proto_index_deep(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_deep_prensor() + ) + new_root = reroot.reroot(expr, path.Path(["event", "doc"])) + proto_index = reroot.create_proto_index_field( + new_root, "proto_index" + ).get_child("proto_index") + new_field = new_root.get_child("bar") + leaf_node = expression_test_util.calculate_value_slowly(new_field) + proto_index_node = expression_test_util.calculate_value_slowly(proto_index) + + self.assertIsNotNone(new_field) + self.assertTrue(new_field.is_repeated) + self.assertEqual(new_field.type, tf.string) + self.assertTrue(new_field.is_leaf) + self.assertEqual(new_field.known_field_names(), frozenset()) + self.assertEqual(leaf_node.values.dtype, tf.string) + + self.assertIsNotNone(proto_index) + self.assertFalse(proto_index.is_repeated) + self.assertEqual(proto_index.type, tf.int64) + self.assertTrue(proto_index.is_leaf) + self.assertEqual(proto_index.known_field_names(), frozenset()) + + self.assertEqual(proto_index_node.values.dtype, tf.int64) + + self.assertAllEqual([b"a", b"b", b"c", b"d"], leaf_node.values) + self.assertAllEqual([0, 1, 1, 2], leaf_node.parent_index) + self.assertAllEqual([0, 1, 1], proto_index_node.values) + self.assertAllEqual([0, 1, 2], proto_index_node.parent_index) + + def test_create_proto_index_directly_reroot_at_action(self): + sessions = [ + """ event { action {} action {} } event {} event { action {} } - """, "", """ + """, + "", + """ event {} event { action {} action {} } event { } - """ - ] - expr = proto_test_util.text_to_expression(sessions, test_pb2.Session) - reroot_expr = expr.reroot("event.action") - # Reroot with a depth > 1 (all the other cases are depth == 1) - proto_index_directly_reroot_at_action = ( - 
reroot_expr.create_proto_index("proto_index_directly_reroot_at_action") - .get_child_or_error("proto_index_directly_reroot_at_action")) - - self.assertFalse(proto_index_directly_reroot_at_action.is_repeated) - result = expression_test_util.calculate_value_slowly( - proto_index_directly_reroot_at_action) - self.assertAllEqual(result.parent_index, [0, 1, 2, 3, 4]) - self.assertAllEqual(result.values, [0, 0, 0, 2, 2]) - - def test_create_proto_index_directly_reroot_at_action_sparse_dense(self): - sessions = [ - """ + """, + ] + expr = proto_test_util.text_to_expression(sessions, test_pb2.Session) + reroot_expr = expr.reroot("event.action") + # Reroot with a depth > 1 (all the other cases are depth == 1) + proto_index_directly_reroot_at_action = reroot_expr.create_proto_index( + "proto_index_directly_reroot_at_action" + ).get_child_or_error("proto_index_directly_reroot_at_action") + + self.assertFalse(proto_index_directly_reroot_at_action.is_repeated) + result = expression_test_util.calculate_value_slowly( + proto_index_directly_reroot_at_action + ) + self.assertAllEqual(result.parent_index, [0, 1, 2, 3, 4]) + self.assertAllEqual(result.values, [0, 0, 0, 2, 2]) + + def test_create_proto_index_directly_reroot_at_action_sparse_dense(self): + sessions = [ + """ event { action {} action {} } event {} event { action {} } - """, "", """ + """, + "", + """ event {} event { action {} action {} } event { } - """ - ] - expr = proto_test_util.text_to_expression(sessions, test_pb2.Session) - reroot_expr = expr.reroot("event.action") - # Reroot with a depth > 1 (all the other cases are depth == 1) - [prensor_tree] = calculate.calculate_prensors([ - reroot_expr.create_proto_index("proto_index_directly_reroot_at_action") - ]) - proto_index_node = prensor_tree.get_child_or_error( - "proto_index_directly_reroot_at_action").node - self.assertFalse(proto_index_node.is_repeated) - sparse_tensors = prensor_tree.get_sparse_tensors() - proto_index_directly_reroot_at_action = sparse_tensors[path.Path( - ["proto_index_directly_reroot_at_action"])] - dense_value = tf.sparse.to_dense( - proto_index_directly_reroot_at_action) - sparse_value = proto_index_directly_reroot_at_action - - self.assertAllEqual(sparse_value.values, [0, 0, 0, 2, 2]) - self.assertAllEqual(sparse_value.indices, [[0], [1], [2], [3], [4]]) - self.assertAllEqual(sparse_value.dense_shape, [5]) - self.assertAllEqual(dense_value, [0, 0, 0, 2, 2]) + """, + ] + expr = proto_test_util.text_to_expression(sessions, test_pb2.Session) + reroot_expr = expr.reroot("event.action") + # Reroot with a depth > 1 (all the other cases are depth == 1) + [prensor_tree] = calculate.calculate_prensors( + [reroot_expr.create_proto_index("proto_index_directly_reroot_at_action")] + ) + proto_index_node = prensor_tree.get_child_or_error( + "proto_index_directly_reroot_at_action" + ).node + self.assertFalse(proto_index_node.is_repeated) + sparse_tensors = prensor_tree.get_sparse_tensors() + proto_index_directly_reroot_at_action = sparse_tensors[ + path.Path(["proto_index_directly_reroot_at_action"]) + ] + dense_value = tf.sparse.to_dense(proto_index_directly_reroot_at_action) + sparse_value = proto_index_directly_reroot_at_action + + self.assertAllEqual(sparse_value.values, [0, 0, 0, 2, 2]) + self.assertAllEqual(sparse_value.indices, [[0], [1], [2], [3], [4]]) + self.assertAllEqual(sparse_value.dense_shape, [5]) + self.assertAllEqual(dense_value, [0, 0, 0, 2, 2]) def test_reroot_lenient_format(self): - expr = create_expression.create_expression_from_prensor( - 
prensor_test_util.create_nested_prensor_with_lenient_field_names(),
-        validate_step_format=False,
-    )
-    new_root = reroot.reroot(expr, path.Path(["doc"]))
-    self.assertFalse(new_root.validate_step_format)
+        expr = create_expression.create_expression_from_prensor(
+            prensor_test_util.create_nested_prensor_with_lenient_field_names(),
+            validate_step_format=False,
+        )
+        new_root = reroot.reroot(expr, path.Path(["doc"]))
+        self.assertFalse(new_root.validate_step_format)
 
 
 if __name__ == "__main__":
-  absltest.main()
+    absltest.main()
diff --git a/struct2tensor/expression_impl/size.py b/struct2tensor/expression_impl/size.py
index 528483d..56cac9a 100644
--- a/struct2tensor/expression_impl/size.py
+++ b/struct2tensor/expression_impl/size.py
@@ -29,6 +29,7 @@
 creates a new expression root that has an optional field "foo.bar_has",
 which is always present, and is true if there are one or more bar in foo.
 """
+
 from typing import Optional, Sequence, Tuple
 
 import tensorflow as tf
@@ -37,125 +38,139 @@
 from struct2tensor.expression_impl import map_values
 
 
-def size_anonymous(root: expression.Expression, source_path: path.Path
-                  ) -> Tuple[expression.Expression, path.Path]:
-  """Calculate the size of a field, and store it as an anonymous sibling.
+def size_anonymous(
+    root: expression.Expression, source_path: path.Path
+) -> Tuple[expression.Expression, path.Path]:
+    """Calculate the size of a field, and store it as an anonymous sibling.
 
-  Args:
-    root: the original expression.
-    source_path: the source path to measure. Cannot be root.
+    Args:
+      root: the original expression.
+      source_path: the source path to measure. Cannot be root.
 
-  Returns:
-    The new expression and the new field as a pair.
-  """
-  return _size_impl(root, source_path, path.get_anonymous_field())
+    Returns:
+      The new expression and the new field as a pair.
+    """
+    return _size_impl(root, source_path, path.get_anonymous_field())
 
 
-def size(root: expression.Expression, source_path: path.Path,
-         new_field_name: path.Step) -> expression.Expression:
-  """Get the size of a field as a new sibling field.
+def size(
+    root: expression.Expression, source_path: path.Path, new_field_name: path.Step
+) -> expression.Expression:
+    """Get the size of a field as a new sibling field.
 
-  Args:
-    root: the original expression.
-    source_path: the source path to measure. Cannot be root.
-    new_field_name: the name of the sibling field.
+    Args:
+      root: the original expression.
+      source_path: the source path to measure. Cannot be root.
+      new_field_name: the name of the sibling field.
 
-  Returns:
-    The new expression.
-  """
-  return _size_impl(root, source_path, new_field_name)[0]
+    Returns:
+      The new expression.
+    """
+    return _size_impl(root, source_path, new_field_name)[0]
 
 
-def has(root: expression.Expression, source_path: path.Path,
-        new_field_name: path.Step) -> expression.Expression:
-  """Get the has of a field as a new sibling field.
+def has(
+    root: expression.Expression, source_path: path.Path, new_field_name: path.Step
+) -> expression.Expression:
+    """Get the presence of a field ("has") as a new sibling field.
 
-  Args:
-    root: the original expression.
-    source_path: the source path to measure. Cannot be root.
-    new_field_name: the name of the sibling field.
+    Args:
+      root: the original expression.
+      source_path: the source path to measure. Cannot be root.
+      new_field_name: the name of the sibling field.
 
-  Returns:
-    The new expression.
- """ - new_root, size_p = size_anonymous(root, source_path) - # TODO(martinz): consider using copy_over to "remove" the size field - # from the result. - return map_values.map_values( - new_root, size_p, lambda x: tf.greater(x, tf.constant(0, dtype=tf.int64)), - tf.bool, new_field_name) + Returns: + The new expression. + """ + new_root, size_p = size_anonymous(root, source_path) + # TODO(martinz): consider using copy_over to "remove" the size field + # from the result. + return map_values.map_values( + new_root, + size_p, + lambda x: tf.greater(x, tf.constant(0, dtype=tf.int64)), + tf.bool, + new_field_name, + ) class SizeExpression(expression.Leaf): - """Size of the given expression. - - SizeExpression is intended to be a sibling of origin. - origin_parent should be the parent of origin. - - """ - - def __init__(self, origin: expression.Expression, - origin_parent: expression.Expression): - super().__init__(False, tf.int64) - self._origin = origin - self._origin_parent = origin_parent - - def get_source_expressions(self) -> Sequence[expression.Expression]: - return [self._origin, self._origin_parent] - - def calculate( - self, - sources: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: - - [origin_value, origin_parent_value] = sources - if not isinstance(origin_value, - (prensor.LeafNodeTensor, prensor.ChildNodeTensor)): - raise ValueError( - "origin_value must be a LeafNodeTensor or a ChildNodeTensor, " - "but was a " + str(type(origin_value))) - - if not isinstance(origin_parent_value, - (prensor.ChildNodeTensor, prensor.RootNodeTensor)): - raise ValueError("origin_parent_value must be a ChildNodeTensor " - "or a RootNodeTensor, but was a " + - str(type(origin_parent_value))) - - parent_index = origin_value.parent_index - num_parent_protos = origin_parent_value.size - # A vector of 1s of the same size as the parent_index. - updates = tf.ones(tf.shape(parent_index), dtype=tf.int64) - indices = tf.expand_dims(parent_index, 1) - # This is incrementing the size by 1 for each element. - # Obviously, not the fastest way to do this. - values = tf.scatter_nd(indices, updates, tf.reshape(num_parent_protos, [1])) - - # Need to create a new_parent_index = 0,1,2,3,4...n. - new_parent_index = tf.range(num_parent_protos, dtype=tf.int64) - return prensor.LeafNodeTensor(new_parent_index, values, False) - - def calculation_is_identity(self) -> bool: - return False - - def calculation_equal(self, expr: expression.Expression) -> bool: - return isinstance(expr, SizeExpression) + """Size of the given expression. + + SizeExpression is intended to be a sibling of origin. + origin_parent should be the parent of origin. 
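+
+    For example (the numbers from size_test below): if origin "doc.bar" has
+    parent_index [0, 1, 1, 2] and its parent "doc" has three elements, the
+    computed leaf is parent_index=[0, 1, 2] with values=[1, 2, 1], i.e. one
+    count per parent element, with parents that have no children counted as 0.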
+ + """ + + def __init__( + self, origin: expression.Expression, origin_parent: expression.Expression + ): + super().__init__(False, tf.int64) + self._origin = origin + self._origin_parent = origin_parent + + def get_source_expressions(self) -> Sequence[expression.Expression]: + return [self._origin, self._origin_parent] + + def calculate( + self, + sources: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> prensor.NodeTensor: + [origin_value, origin_parent_value] = sources + if not isinstance( + origin_value, (prensor.LeafNodeTensor, prensor.ChildNodeTensor) + ): + raise ValueError( + "origin_value must be a LeafNodeTensor or a ChildNodeTensor, " + "but was a " + str(type(origin_value)) + ) + + if not isinstance( + origin_parent_value, (prensor.ChildNodeTensor, prensor.RootNodeTensor) + ): + raise ValueError( + "origin_parent_value must be a ChildNodeTensor " + "or a RootNodeTensor, but was a " + str(type(origin_parent_value)) + ) + + parent_index = origin_value.parent_index + num_parent_protos = origin_parent_value.size + # A vector of 1s of the same size as the parent_index. + updates = tf.ones(tf.shape(parent_index), dtype=tf.int64) + indices = tf.expand_dims(parent_index, 1) + # This is incrementing the size by 1 for each element. + # Obviously, not the fastest way to do this. + values = tf.scatter_nd(indices, updates, tf.reshape(num_parent_protos, [1])) + + # Need to create a new_parent_index = 0,1,2,3,4...n. + new_parent_index = tf.range(num_parent_protos, dtype=tf.int64) + return prensor.LeafNodeTensor(new_parent_index, values, False) + + def calculation_is_identity(self) -> bool: + return False + + def calculation_equal(self, expr: expression.Expression) -> bool: + return isinstance(expr, SizeExpression) def _size_impl( - root: expression.Expression, source_path: path.Path, - new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]: - if not source_path: - raise ValueError("Cannot get the size of the root.") - if root.get_descendant(source_path) is None: - raise ValueError(f"Path not found: {str(source_path)}") - parent_path = source_path.get_parent() - new_path = parent_path.get_child(new_field_name) - return expression_add.add_paths( - root, { - new_path: - SizeExpression( - root.get_descendant_or_error(source_path), - root.get_descendant_or_error(parent_path)) - }), new_path + root: expression.Expression, source_path: path.Path, new_field_name: path.Step +) -> Tuple[expression.Expression, path.Path]: + if not source_path: + raise ValueError("Cannot get the size of the root.") + if root.get_descendant(source_path) is None: + raise ValueError(f"Path not found: {str(source_path)}") + parent_path = source_path.get_parent() + new_path = parent_path.get_child(new_field_name) + return expression_add.add_paths( + root, + { + new_path: SizeExpression( + root.get_descendant_or_error(source_path), + root.get_descendant_or_error(parent_path), + ) + }, + ), new_path diff --git a/struct2tensor/expression_impl/size_test.py b/struct2tensor/expression_impl/size_test.py index a886b37..aadc63c 100644 --- a/struct2tensor/expression_impl/size_test.py +++ b/struct2tensor/expression_impl/size_test.py @@ -16,7 +16,7 @@ import tensorflow as tf from absl.testing import absltest from tensorflow.python.framework import ( - test_util, # pylint: disable=g-direct-tensorflow-import + test_util, # pylint: disable=g-direct-tensorflow-import ) from struct2tensor import 
create_expression, path @@ -26,43 +26,46 @@ @test_util.run_all_in_graph_and_eager_modes class SizeTest(tf.test.TestCase): + def test_size_anonymous(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + new_root, new_path = size.size_anonymous(expr, path.Path(["doc", "bar"])) + new_field = new_root.get_descendant_or_error(new_path) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) + self.assertAllEqual(leaf_node.values, [1, 2, 1]) - def test_size_anonymous(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - new_root, new_path = size.size_anonymous(expr, path.Path(["doc", "bar"])) - new_field = new_root.get_descendant_or_error(new_path) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) - self.assertAllEqual(leaf_node.values, [1, 2, 1]) + def test_size(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + new_root = size.size(expr, path.Path(["doc", "bar"]), "result") + new_field = new_root.get_descendant_or_error(path.Path(["doc", "result"])) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) + self.assertAllEqual(leaf_node.values, [1, 2, 1]) - def test_size(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - new_root = size.size(expr, path.Path(["doc", "bar"]), "result") - new_field = new_root.get_descendant_or_error(path.Path(["doc", "result"])) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) - self.assertAllEqual(leaf_node.values, [1, 2, 1]) + def test_size_missing_value(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + new_root = size.size(expr, path.Path(["doc", "keep_me"]), "result") + new_field = new_root.get_descendant_or_error(path.Path(["doc", "result"])) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) + self.assertAllEqual(leaf_node.values, [1, 1, 0]) - def test_size_missing_value(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - new_root = size.size(expr, path.Path(["doc", "keep_me"]), "result") - new_field = new_root.get_descendant_or_error(path.Path(["doc", "result"])) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) - self.assertAllEqual(leaf_node.values, [1, 1, 0]) - - def test_has(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - new_root = size.has(expr, path.Path(["doc", "keep_me"]), "result") - new_field = new_root.get_descendant_or_error(path.Path(["doc", "result"])) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) - self.assertAllEqual(leaf_node.values, [True, True, False]) + def test_has(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + new_root = size.has(expr, path.Path(["doc", "keep_me"]), "result") + new_field = new_root.get_descendant_or_error(path.Path(["doc", 
"result"])) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) + self.assertAllEqual(leaf_node.values, [True, True, False]) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_impl/slice_expression.py b/struct2tensor/expression_impl/slice_expression.py index 81507dc..f5321c5 100644 --- a/struct2tensor/expression_impl/slice_expression.py +++ b/struct2tensor/expression_impl/slice_expression.py @@ -122,145 +122,175 @@ IndexValue = expression.IndexValue -def slice_expression(expr: expression.Expression, p: path.Path, - new_field_name: path.Step, begin: Optional[IndexValue], - end: Optional[IndexValue]) -> expression.Expression: - """Creates a new subtree with a sliced expression. - - This follows the pattern of python slice() method. - See module-level comments for examples. - - Args: - expr: the original root expression - p: the path to the source to be sliced. - new_field_name: the name of the new subtree. - begin: beginning index - end: end index. - - Returns: - A new root expression. - """ - work_expr, mask_anonymous_path = _get_slice_mask(expr, p, begin, end) - work_expr = filter_expression.filter_by_sibling( - work_expr, p, mask_anonymous_path.field_list[-1], new_field_name) - new_path = p.get_parent().get_child(new_field_name) - # We created a lot of anonymous fields and intermediate expressions. Just grab - # the final result (and its children). - return expression_add.add_to(expr, {new_path: work_expr}) - - -def _get_mask(t: expression.Expression, p: path.Path, threshold: IndexValue, - relation: Callable[[tf.Tensor, IndexValue], tf.Tensor] - ) -> Tuple[expression.Expression, path.Path]: - """Gets a mask based on a relation of the index to a threshold. - - If the threshold is non-negative, then we create a mask that is true if - the relation(index, threshold) is true. - - If the threshold is negative, then we create a mask that is true if - the relation(size - index, threshold) is true. - - Args: - t: expression to add the field to. - p: path to create the mask for. - threshold: the cutoff threshold. - relation: tf.less or tf.greater_equal. - - Returns: - A boolean mask on the fields to keep on the model. - - Raises: - ValueError: if p is not in t. 
- """ - if t.get_descendant(p) is None: - raise ValueError(f"Path not found: {str(p)}") - work_expr, index_from_end = index.get_index_from_end( - t, p, path.get_anonymous_field()) - work_expr, mask_for_negative_threshold = map_values.map_values_anonymous( - work_expr, - index_from_end, lambda x: relation(x, tf.cast(threshold, tf.int64)), - tf.bool) - - work_expr, positional_index = index.get_positional_index( - work_expr, p, path.get_anonymous_field()) - work_expr, mask_for_non_negative_threshold = map_values.map_values_anonymous( - work_expr, - positional_index, lambda x: relation(x, tf.cast(threshold, tf.int64)), - tf.bool) - - if isinstance(threshold, int): - if threshold >= 0: - return work_expr, mask_for_non_negative_threshold - return work_expr, mask_for_negative_threshold - - def tf_cond_on_threshold(a, b): - return tf.cond(tf.greater_equal(threshold, 0), a, b) - - return map_values.map_many_values(work_expr, p.get_parent(), [ - x.field_list[-1] - for x in [mask_for_non_negative_threshold, mask_for_negative_threshold] - ], tf_cond_on_threshold, tf.bool, path.get_anonymous_field()) - - -def _get_begin_mask(expr: expression.Expression, p: path.Path, begin: IndexValue - ) -> Tuple[expression.Expression, path.Path]: - """Get a boolean mask of what indices to retain for slice given begin.""" - return _get_mask(expr, p, begin, tf.greater_equal) - - -def _get_end_mask(t: expression.Expression, p: path.Path, - end: IndexValue) -> Tuple[expression.Expression, path.Path]: - """Get a boolean mask of what indices to retain for slice given end.""" - return _get_mask(t, p, end, tf.less) +def slice_expression( + expr: expression.Expression, + p: path.Path, + new_field_name: path.Step, + begin: Optional[IndexValue], + end: Optional[IndexValue], +) -> expression.Expression: + """Creates a new subtree with a sliced expression. + + This follows the pattern of python slice() method. + See module-level comments for examples. + + Args: + expr: the original root expression + p: the path to the source to be sliced. + new_field_name: the name of the new subtree. + begin: beginning index + end: end index. + + Returns: + A new root expression. + """ + work_expr, mask_anonymous_path = _get_slice_mask(expr, p, begin, end) + work_expr = filter_expression.filter_by_sibling( + work_expr, p, mask_anonymous_path.field_list[-1], new_field_name + ) + new_path = p.get_parent().get_child(new_field_name) + # We created a lot of anonymous fields and intermediate expressions. Just grab + # the final result (and its children). + return expression_add.add_to(expr, {new_path: work_expr}) + + +def _get_mask( + t: expression.Expression, + p: path.Path, + threshold: IndexValue, + relation: Callable[[tf.Tensor, IndexValue], tf.Tensor], +) -> Tuple[expression.Expression, path.Path]: + """Gets a mask based on a relation of the index to a threshold. + + If the threshold is non-negative, then we create a mask that is true if + the relation(index, threshold) is true. + + If the threshold is negative, then we create a mask that is true if + the relation(size - index, threshold) is true. + + Args: + t: expression to add the field to. + p: path to create the mask for. + threshold: the cutoff threshold. + relation: tf.less or tf.greater_equal. + + Returns: + A boolean mask on the fields to keep on the model. + + Raises: + ValueError: if p is not in t. 
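+
+    For example (the numbers from slice_expression_test below): if the three
+    elements of "doc" have positional indices [0, 0, 1] under their parents,
+    then threshold=1 with relation=tf.greater_equal produces the mask
+    [False, False, True], keeping only elements at position 1 or later.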
+ """ + if t.get_descendant(p) is None: + raise ValueError(f"Path not found: {str(p)}") + work_expr, index_from_end = index.get_index_from_end( + t, p, path.get_anonymous_field() + ) + work_expr, mask_for_negative_threshold = map_values.map_values_anonymous( + work_expr, + index_from_end, + lambda x: relation(x, tf.cast(threshold, tf.int64)), + tf.bool, + ) + + work_expr, positional_index = index.get_positional_index( + work_expr, p, path.get_anonymous_field() + ) + work_expr, mask_for_non_negative_threshold = map_values.map_values_anonymous( + work_expr, + positional_index, + lambda x: relation(x, tf.cast(threshold, tf.int64)), + tf.bool, + ) + + if isinstance(threshold, int): + if threshold >= 0: + return work_expr, mask_for_non_negative_threshold + return work_expr, mask_for_negative_threshold + + def tf_cond_on_threshold(a, b): + return tf.cond(tf.greater_equal(threshold, 0), a, b) + + return map_values.map_many_values( + work_expr, + p.get_parent(), + [ + x.field_list[-1] + for x in [mask_for_non_negative_threshold, mask_for_negative_threshold] + ], + tf_cond_on_threshold, + tf.bool, + path.get_anonymous_field(), + ) + + +def _get_begin_mask( + expr: expression.Expression, p: path.Path, begin: IndexValue +) -> Tuple[expression.Expression, path.Path]: + """Get a boolean mask of what indices to retain for slice given begin.""" + return _get_mask(expr, p, begin, tf.greater_equal) + + +def _get_end_mask( + t: expression.Expression, p: path.Path, end: IndexValue +) -> Tuple[expression.Expression, path.Path]: + """Get a boolean mask of what indices to retain for slice given end.""" + return _get_mask(t, p, end, tf.less) def _get_slice_mask( - expr: expression.Expression, p: path.Path, begin: Optional[IndexValue], - end: Optional[IndexValue]) -> Tuple[expression.Expression, path.Path]: - """Gets a mask for slicing a path. - - One way to consider the elements of a path "foo.bar" is as a list of list of - list of elements. Slicing a path slices this doubly nested list of elements, - based upon positions in its parent list. Each parent list has a size, and - there is a beginning and end relative to the elements in that list. - - At each path p, there is conceptually a list of...list of elements. - - For example, given: - an index with respect to its parent - The range is specified with beginning and an end. - 1. If begin is not present, begin_index is implied to be zero. - 2. If begin is negative, begin_index is the size of a particular - list + begin - 3. If end is not present, end_index is the length of the list + 1. - 4. If end is negative, end_index is the length of the list + end - 5. If end is non-negative, end_index is end. - The mask is positive for all elements in range(begin_index, end_index), and - negative elsewhere. - - The mask returned is a sibling of path p, where for every element in p, there - is a corresponding element in the mask. - - Args: - expr: the root expression - p: the path to be sliced - begin: the beginning index - end: the ending index - - Returns: - An expression,path pair, where the expression contains all the children in - `expr` and an anonymous field of the mask and the path points to - the mask field. - """ - if begin is None: + expr: expression.Expression, + p: path.Path, + begin: Optional[IndexValue], + end: Optional[IndexValue], +) -> Tuple[expression.Expression, path.Path]: + """Gets a mask for slicing a path. + + One way to consider the elements of a path "foo.bar" is as a list of list of + list of elements. 
Slicing a path slices this doubly nested list of elements, + based upon positions in its parent list. Each parent list has a size, and + there is a beginning and end relative to the elements in that list. + + At each path p, there is conceptually a list of...list of elements. + + For example, given: + an index with respect to its parent + The range is specified with beginning and an end. + 1. If begin is not present, begin_index is implied to be zero. + 2. If begin is negative, begin_index is the size of a particular + list + begin + 3. If end is not present, end_index is the length of the list + 1. + 4. If end is negative, end_index is the length of the list + end + 5. If end is non-negative, end_index is end. + The mask is positive for all elements in range(begin_index, end_index), and + negative elsewhere. + + The mask returned is a sibling of path p, where for every element in p, there + is a corresponding element in the mask. + + Args: + expr: the root expression + p: the path to be sliced + begin: the beginning index + end: the ending index + + Returns: + An expression,path pair, where the expression contains all the children in + `expr` and an anonymous field of the mask and the path points to + the mask field. + """ + if begin is None: + if end is None: + raise ValueError("Must specify begin or end.") + return _get_end_mask(expr, p, end) if end is None: - raise ValueError("Must specify begin or end.") - return _get_end_mask(expr, p, end) - if end is None: - return _get_begin_mask(expr, p, begin) - work_expr, begin_mask = _get_begin_mask(expr, p, begin) - work_expr, end_mask = _get_end_mask(work_expr, p, end) - return map_values.map_many_values( - work_expr, p.get_parent(), - [x.field_list[-1] for x in [begin_mask, end_mask]], tf.logical_and, - tf.bool, path.get_anonymous_field()) + return _get_begin_mask(expr, p, begin) + work_expr, begin_mask = _get_begin_mask(expr, p, begin) + work_expr, end_mask = _get_end_mask(work_expr, p, end) + return map_values.map_many_values( + work_expr, + p.get_parent(), + [x.field_list[-1] for x in [begin_mask, end_mask]], + tf.logical_and, + tf.bool, + path.get_anonymous_field(), + ) diff --git a/struct2tensor/expression_impl/slice_expression_test.py b/struct2tensor/expression_impl/slice_expression_test.py index 9f87b1e..b282a2b 100644 --- a/struct2tensor/expression_impl/slice_expression_test.py +++ b/struct2tensor/expression_impl/slice_expression_test.py @@ -16,14 +16,14 @@ import tensorflow as tf from absl.testing import absltest from tensorflow.python.framework import ( - test_util, # pylint: disable=g-direct-tensorflow-import + test_util, # pylint: disable=g-direct-tensorflow-import ) # For tf.Session.Run against a Prensor from struct2tensor import ( - calculate, - create_expression, - path, # pylint: disable=unused-import + calculate, + create_expression, + path, # pylint: disable=unused-import ) from struct2tensor.expression_impl import slice_expression from struct2tensor.test import prensor_test_util @@ -31,213 +31,242 @@ @test_util.run_all_in_graph_and_eager_modes class SliceExpressionTest(tf.test.TestCase): + def test_slice_end(self): + root = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + root_2 = slice_expression.slice_expression( + root, path.Path(["doc"]), "new_doc", None, 1 + ) + result = calculate.calculate_prensors([root_2])[0] - def test_slice_end(self): - root = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - root_2 = 
slice_expression.slice_expression(root, path.Path(["doc"]), - "new_doc", None, 1) - result = calculate.calculate_prensors([root_2])[0] - - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc" - ])).node.parent_index, [0, 1]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", "keep_me" - ])).node.parent_index, [0, 1]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "keep_me"])).node.values, - [False, True]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "bar"])).node.parent_index, - [0, 1, 1]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "bar"])).node.values, - [b"a", b"b", b"c"]) - - def test_slice_begin(self): - """Test slice with only begin specified. - - Starts with: - { - foo:9, - foorepeated:[9], - doc:[{ - bar:["a"], - keep_me:False - }], - user:[ + self.assertAllEqual( + result.get_descendant_or_error(path.Path(["new_doc"])).node.parent_index, + [0, 1], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "keep_me"]) + ).node.parent_index, + [0, 1], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "keep_me"]) + ).node.values, + [False, True], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "bar"]) + ).node.parent_index, + [0, 1, 1], + ) + self.assertAllEqual( + result.get_descendant_or_error(path.Path(["new_doc", "bar"])).node.values, + [b"a", b"b", b"c"], + ) + + def test_slice_begin(self): + """Test slice with only begin specified. + + Starts with: { - friends:["a"] - }] - } - {foo:8, - foorepeated:[8,7], - doc:[{ - bar:["b","c"], - keep_me:True - },{ - bar:["d"] - }], - user:[{ - friends:["b", "c"] - },{ - friends:["d"] - }], - } - {foo:7, - foorepeated:[6], - user:[{friends:["e"]}]} - - Creates new_doc by slicing doc[1:]: - {foo:9, - foorepeated:[9], - doc:[{ - bar:["a"], - keep_me:False - }], - user:[{ - friends:["a"] - }]} - {foo:8, - foorepeated:[8,7], - doc:[{ - bar:["b","c"], - keep_me:True - },{ - bar:["d"] - }], - new_doc[{ - bar:["d"] - }], - user:[{ - friends:["b", "c"] - },{ - friends:["d"]}],} - {foo:7, - foorepeated:[6], - user:[{ - friends:["e"] - }]} - """ - root = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - root_2 = slice_expression.slice_expression(root, path.Path(["doc"]), - "new_doc", 1, None) - result = calculate.calculate_prensors([root_2])[0] - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc" - ])).node.parent_index, [1]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", "keep_me" - ])).node.parent_index, []) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "keep_me"])).node.values, []) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "bar"])).node.parent_index, - [0]) - self.assertAllEqual( - result.get_descendant_or_error(path.Path(["new_doc", - "bar"])).node.values, [b"d"]) - - def test_slice_mask(self): - root = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - root_2, new_path = slice_expression._get_slice_mask(root, - path.Path(["doc"]), - None, 1) - result = calculate.calculate_prensors([root_2])[0] - self.assertAllEqual( - result.get_descendant_or_error(new_path).node.parent_index, [0, 1, 1]) - self.assertAllEqual( - result.get_descendant_or_error(new_path).node.values, - [True, True, False]) - 
- def test_slice_mask_end_negative(self): - root = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - root_2, new_path = slice_expression._get_slice_mask(root, - path.Path(["doc"]), - None, -1) - result = calculate.calculate_prensors([root_2])[0] - self.assertAllEqual( - result.get_descendant_or_error(new_path).node.parent_index, [0, 1, 1]) - self.assertAllEqual( - result.get_descendant_or_error(new_path).node.values, - [False, True, False]) - - def test_slice_mask_begin_positive(self): - root = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - root_2, new_path = slice_expression._get_slice_mask(root, - path.Path(["doc"]), 1, - None) - [result] = calculate.calculate_prensors([root_2]) - self.assertAllEqual( - result.get_descendant_or_error(new_path).node.parent_index, [0, 1, 1]) - self.assertAllEqual( - result.get_descendant_or_error(new_path).node.values, - [False, False, True]) - - def test_slice_mask_begin_negative(self): - root = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - root_2, new_path = slice_expression._get_slice_mask(root, - path.Path(["doc"]), -1, - None) - result = calculate.calculate_prensors([root_2])[0] - self.assertAllEqual( - result.get_descendant_or_error(new_path).node.parent_index, [0, 1, 1]) - self.assertAllEqual( - result.get_descendant_or_error(new_path).node.values, - [True, False, True]) - - def test_slice_end_lenient_formatting(self): - root = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor_with_lenient_field_names(), - validate_step_format=False, - ) - root_2 = slice_expression.slice_expression( - root, path.Path(["doc"], validate_step_format=False), "new/doc", None, 1 - ) - result = calculate.calculate_prensors([root_2])[0] - - self.assertAllEqual( - result.get_descendant_or_error( - path.Path(["new/doc"], validate_step_format=False) - ).node.parent_index, - [0, 1], - ) - self.assertAllEqual( - result.get_descendant_or_error( - path.Path(["new/doc", "keep_me/x"], validate_step_format=False) - ).node.parent_index, - [0, 1], - ) - self.assertAllEqual( - result.get_descendant_or_error( - path.Path(["new/doc", "keep_me/x"], validate_step_format=False) - ).node.values, - [False, True], - ) - self.assertAllEqual( - result.get_descendant_or_error( - path.Path(["new/doc", "bar.baz"], validate_step_format=False) - ).node.parent_index, - [0, 1, 1], - ) - self.assertAllEqual( - result.get_descendant_or_error( - path.Path(["new/doc", "bar.baz"], validate_step_format=False) - ).node.values, - [1, 2, 3], - ) + foo:9, + foorepeated:[9], + doc:[{ + bar:["a"], + keep_me:False + }], + user:[ + { + friends:["a"] + }] + } + {foo:8, + foorepeated:[8,7], + doc:[{ + bar:["b","c"], + keep_me:True + },{ + bar:["d"] + }], + user:[{ + friends:["b", "c"] + },{ + friends:["d"] + }], + } + {foo:7, + foorepeated:[6], + user:[{friends:["e"]}]} + + Creates new_doc by slicing doc[1:]: + {foo:9, + foorepeated:[9], + doc:[{ + bar:["a"], + keep_me:False + }], + user:[{ + friends:["a"] + }]} + {foo:8, + foorepeated:[8,7], + doc:[{ + bar:["b","c"], + keep_me:True + },{ + bar:["d"] + }], + new_doc[{ + bar:["d"] + }], + user:[{ + friends:["b", "c"] + },{ + friends:["d"]}],} + {foo:7, + foorepeated:[6], + user:[{ + friends:["e"] + }]} + """ + root = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + root_2 = slice_expression.slice_expression( + root, path.Path(["doc"]), 
"new_doc", 1, None + ) + result = calculate.calculate_prensors([root_2])[0] + self.assertAllEqual( + result.get_descendant_or_error(path.Path(["new_doc"])).node.parent_index, + [1], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "keep_me"]) + ).node.parent_index, + [], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "keep_me"]) + ).node.values, + [], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new_doc", "bar"]) + ).node.parent_index, + [0], + ) + self.assertAllEqual( + result.get_descendant_or_error(path.Path(["new_doc", "bar"])).node.values, + [b"d"], + ) + + def test_slice_mask(self): + root = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + root_2, new_path = slice_expression._get_slice_mask( + root, path.Path(["doc"]), None, 1 + ) + result = calculate.calculate_prensors([root_2])[0] + self.assertAllEqual( + result.get_descendant_or_error(new_path).node.parent_index, [0, 1, 1] + ) + self.assertAllEqual( + result.get_descendant_or_error(new_path).node.values, [True, True, False] + ) + + def test_slice_mask_end_negative(self): + root = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + root_2, new_path = slice_expression._get_slice_mask( + root, path.Path(["doc"]), None, -1 + ) + result = calculate.calculate_prensors([root_2])[0] + self.assertAllEqual( + result.get_descendant_or_error(new_path).node.parent_index, [0, 1, 1] + ) + self.assertAllEqual( + result.get_descendant_or_error(new_path).node.values, [False, True, False] + ) + + def test_slice_mask_begin_positive(self): + root = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + root_2, new_path = slice_expression._get_slice_mask( + root, path.Path(["doc"]), 1, None + ) + [result] = calculate.calculate_prensors([root_2]) + self.assertAllEqual( + result.get_descendant_or_error(new_path).node.parent_index, [0, 1, 1] + ) + self.assertAllEqual( + result.get_descendant_or_error(new_path).node.values, [False, False, True] + ) + + def test_slice_mask_begin_negative(self): + root = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + root_2, new_path = slice_expression._get_slice_mask( + root, path.Path(["doc"]), -1, None + ) + result = calculate.calculate_prensors([root_2])[0] + self.assertAllEqual( + result.get_descendant_or_error(new_path).node.parent_index, [0, 1, 1] + ) + self.assertAllEqual( + result.get_descendant_or_error(new_path).node.values, [True, False, True] + ) + + def test_slice_end_lenient_formatting(self): + root = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor_with_lenient_field_names(), + validate_step_format=False, + ) + root_2 = slice_expression.slice_expression( + root, path.Path(["doc"], validate_step_format=False), "new/doc", None, 1 + ) + result = calculate.calculate_prensors([root_2])[0] + + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new/doc"], validate_step_format=False) + ).node.parent_index, + [0, 1], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new/doc", "keep_me/x"], validate_step_format=False) + ).node.parent_index, + [0, 1], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new/doc", "keep_me/x"], validate_step_format=False) + ).node.values, + [False, True], + ) + self.assertAllEqual( + 
result.get_descendant_or_error( + path.Path(["new/doc", "bar.baz"], validate_step_format=False) + ).node.parent_index, + [0, 1, 1], + ) + self.assertAllEqual( + result.get_descendant_or_error( + path.Path(["new/doc", "bar.baz"], validate_step_format=False) + ).node.values, + [1, 2, 3], + ) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/expression_test.py b/struct2tensor/expression_test.py index 0662133..6a53686 100644 --- a/struct2tensor/expression_test.py +++ b/struct2tensor/expression_test.py @@ -23,7 +23,7 @@ import tensorflow as tf from absl.testing import absltest from tensorflow.python.framework import ( - test_util, # pylint: disable=g-direct-tensorflow-import + test_util, # pylint: disable=g-direct-tensorflow-import ) from tensorflow_metadata.proto.v0 import schema_pb2 @@ -32,232 +32,248 @@ def _features_as_map(feature_list): - return {feature.name: feature for feature in feature_list} + return {feature.name: feature for feature in feature_list} class ExpressionTest(absltest.TestCase): - - def test_promote(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor()) - new_root = expr.promote("user.friends", "new_field") - new_field = new_root.get_child_or_error("new_field") - self.assertIsNotNone(new_field) - self.assertTrue(new_field.is_repeated) - self.assertEqual(new_field.type, tf.string) - self.assertTrue(new_field.is_leaf) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertEqual(leaf_node.values.dtype, tf.string) - self.assertEqual(new_field.known_field_names(), frozenset()) - - def test_broadcast(self): - """Tests broadcast.broadcast(...), and indirectly tests set_path.""" - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - new_root = expr.broadcast("foo", "user", "new_field") - new_field = new_root.get_child("user").get_child("new_field") - self.assertIsNotNone(new_field) - self.assertFalse(new_field.is_repeated) - self.assertEqual(new_field.type, tf.int32) - self.assertTrue(new_field.is_leaf) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertEqual(leaf_node.values.dtype, tf.int32) - self.assertEqual(new_field.known_field_names(), frozenset()) - - def test_project(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_nested_prensor()) - projected = expr.project( - [path.Path(["user", "friends"]), - path.Path(["doc", "keep_me"])]) - self.assertIsNotNone( - projected.get_descendant(path.Path(["user", "friends"]))) - self.assertIsNotNone( - projected.get_descendant(path.Path(["doc", "keep_me"]))) - self.assertIsNone(projected.get_descendant(path.Path(["doc", "bar"]))) - - def test_promote_and_broadcast_test(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - new_root = expr.promote_and_broadcast({"new_field": "user.friends"}, - path.Path(["doc"])) - new_field = new_root.get_descendant_or_error( - path.Path(["doc", "new_field"])) - self.assertTrue(new_field.is_repeated) - self.assertEqual(new_field.type, tf.string) - self.assertTrue(new_field.is_leaf) - self.assertTrue(new_field.calculation_equal(new_field)) - self.assertFalse(new_field.calculation_equal(expr)) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertEqual(leaf_node.values.dtype, tf.string) - self.assertEqual(new_field.known_field_names(), frozenset()) - - def test_get_schema(self): - 
foo_feature = schema_pb2.Feature() - foo_feature.int_domain.max = 10 - foo = expression_test_util.MockExpression( - is_repeated=False, my_type=tf.int64, schema_feature=foo_feature) - foorepeated = expression_test_util.MockExpression( - is_repeated=True, my_type=tf.int64) - bar_feature = schema_pb2.Feature() - bar_feature.presence.min_count = 17 - bar = expression_test_util.MockExpression( - is_repeated=True, my_type=tf.string, schema_feature=bar_feature) - keep_me = expression_test_util.MockExpression( - is_repeated=False, my_type=tf.bool) - - doc = expression_test_util.MockExpression( - is_repeated=True, my_type=tf.int64, children={"bar": bar, - "keep_me": keep_me}) - root = expression_test_util.MockExpression( - is_repeated=True, - my_type=None, - children={ - "foo": foo, - "foorepeated": foorepeated, - "doc": doc - }) - - schema_result = root.get_schema() - feature_map = _features_as_map(schema_result.feature) - self.assertIn("foo", feature_map) - # Check the properties of a first-level feature. - self.assertEqual(feature_map["foo"].int_domain.max, 10) - self.assertIn("foorepeated", feature_map) - - self.assertEqual(feature_map["doc"].type, schema_pb2.FeatureType.STRUCT) - doc_feature_map = _features_as_map(feature_map["doc"].struct_domain.feature) - # Test that second level features are correctly handled. - self.assertIn("bar", doc_feature_map) - # Test that an string_domain specified at the schema level is inserted - # correctly. - self.assertEqual(doc_feature_map["bar"].presence.min_count, 17) - self.assertIn("keep_me", doc_feature_map) - - def test_get_schema_missing_features(self): - # The expr has a number of features: foo, foorepeated, doc, user. - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - # The schema has only a subset of the features on the expr. - schema = schema_pb2.Schema() - feature = schema.feature.add() - feature.name = "foo" - feature.type = schema_pb2.FeatureType.INT - feature.value_count.min = 1 - feature.value_count.max = 1 - feature = schema.feature.add() - feature.name = "foorepeated" - feature.type = schema_pb2.FeatureType.INT - feature.value_count.min = 0 - feature.value_count.max = 5 - feature = schema.feature.add() - feature.name = "doc" - feature.type = schema_pb2.FeatureType.STRUCT - feature.struct_domain.feature.append( - schema_pb2.Feature(name="keep_me", type=schema_pb2.FeatureType.INT)) - - # By default, the output schema has all features present in the expr. - expr = expr.apply_schema(schema) - output_schema = expr.get_schema() - self.assertNotEqual(schema, output_schema) - self.assertLen(schema.feature, 3) - self.assertLen(output_schema.feature, 4) - - # With create_schema_features = False, only features on the original schema - # propagate to the new schema. - output_schema = expr.get_schema(create_schema_features=False) - self.assertLen(output_schema.feature, 3) - - def test_get_schema_empty_schema(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - schema = schema_pb2.Schema() - # By default, the output schema has all features present in the proto. - expr = expr.apply_schema(schema) - output_schema = expr.get_schema() - self.assertEmpty(schema.feature) - self.assertLen(output_schema.feature, 4) - - # With create_schema_features = False, the schema will be empty. 
- output_schema = expr.get_schema(create_schema_features=False) - self.assertEmpty(output_schema.feature) - self.assertEqual(schema, output_schema) - - def test_get_schema_no_schema(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - output_schema = expr.get_schema() - self.assertLen(output_schema.feature, 4) - - # With create_schema_features = False, the schema will be empty. - output_schema = expr.get_schema(create_schema_features=False) - self.assertEmpty(output_schema.feature) - self.assertEqual(schema_pb2.Schema(), output_schema) + def test_promote(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor() + ) + new_root = expr.promote("user.friends", "new_field") + new_field = new_root.get_child_or_error("new_field") + self.assertIsNotNone(new_field) + self.assertTrue(new_field.is_repeated) + self.assertEqual(new_field.type, tf.string) + self.assertTrue(new_field.is_leaf) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertEqual(leaf_node.values.dtype, tf.string) + self.assertEqual(new_field.known_field_names(), frozenset()) + + def test_broadcast(self): + """Tests broadcast.broadcast(...), and indirectly tests set_path.""" + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + new_root = expr.broadcast("foo", "user", "new_field") + new_field = new_root.get_child("user").get_child("new_field") + self.assertIsNotNone(new_field) + self.assertFalse(new_field.is_repeated) + self.assertEqual(new_field.type, tf.int32) + self.assertTrue(new_field.is_leaf) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertEqual(leaf_node.values.dtype, tf.int32) + self.assertEqual(new_field.known_field_names(), frozenset()) + + def test_project(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_nested_prensor() + ) + projected = expr.project( + [path.Path(["user", "friends"]), path.Path(["doc", "keep_me"])] + ) + self.assertIsNotNone(projected.get_descendant(path.Path(["user", "friends"]))) + self.assertIsNotNone(projected.get_descendant(path.Path(["doc", "keep_me"]))) + self.assertIsNone(projected.get_descendant(path.Path(["doc", "bar"]))) + + def test_promote_and_broadcast_test(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + new_root = expr.promote_and_broadcast( + {"new_field": "user.friends"}, path.Path(["doc"]) + ) + new_field = new_root.get_descendant_or_error(path.Path(["doc", "new_field"])) + self.assertTrue(new_field.is_repeated) + self.assertEqual(new_field.type, tf.string) + self.assertTrue(new_field.is_leaf) + self.assertTrue(new_field.calculation_equal(new_field)) + self.assertFalse(new_field.calculation_equal(expr)) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertEqual(leaf_node.values.dtype, tf.string) + self.assertEqual(new_field.known_field_names(), frozenset()) + + def test_get_schema(self): + foo_feature = schema_pb2.Feature() + foo_feature.int_domain.max = 10 + foo = expression_test_util.MockExpression( + is_repeated=False, my_type=tf.int64, schema_feature=foo_feature + ) + foorepeated = expression_test_util.MockExpression( + is_repeated=True, my_type=tf.int64 + ) + bar_feature = schema_pb2.Feature() + bar_feature.presence.min_count = 17 + bar = expression_test_util.MockExpression( + is_repeated=True, my_type=tf.string, 
schema_feature=bar_feature + ) + keep_me = expression_test_util.MockExpression( + is_repeated=False, my_type=tf.bool + ) + + doc = expression_test_util.MockExpression( + is_repeated=True, + my_type=tf.int64, + children={"bar": bar, "keep_me": keep_me}, + ) + root = expression_test_util.MockExpression( + is_repeated=True, + my_type=None, + children={"foo": foo, "foorepeated": foorepeated, "doc": doc}, + ) + + schema_result = root.get_schema() + feature_map = _features_as_map(schema_result.feature) + self.assertIn("foo", feature_map) + # Check the properties of a first-level feature. + self.assertEqual(feature_map["foo"].int_domain.max, 10) + self.assertIn("foorepeated", feature_map) + + self.assertEqual(feature_map["doc"].type, schema_pb2.FeatureType.STRUCT) + doc_feature_map = _features_as_map(feature_map["doc"].struct_domain.feature) + # Test that second-level features are correctly handled. + self.assertIn("bar", doc_feature_map) + # Test that a presence.min_count specified at the schema level is inserted + # correctly. + self.assertEqual(doc_feature_map["bar"].presence.min_count, 17) + self.assertIn("keep_me", doc_feature_map) + + def test_get_schema_missing_features(self): + # The expr has a number of features: foo, foorepeated, doc, user. + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + # The schema has only a subset of the features on the expr. + schema = schema_pb2.Schema() + feature = schema.feature.add() + feature.name = "foo" + feature.type = schema_pb2.FeatureType.INT + feature.value_count.min = 1 + feature.value_count.max = 1 + feature = schema.feature.add() + feature.name = "foorepeated" + feature.type = schema_pb2.FeatureType.INT + feature.value_count.min = 0 + feature.value_count.max = 5 + feature = schema.feature.add() + feature.name = "doc" + feature.type = schema_pb2.FeatureType.STRUCT + feature.struct_domain.feature.append( + schema_pb2.Feature(name="keep_me", type=schema_pb2.FeatureType.INT) + ) + + # By default, the output schema has all features present in the expr. + expr = expr.apply_schema(schema) + output_schema = expr.get_schema() + self.assertNotEqual(schema, output_schema) + self.assertLen(schema.feature, 3) + self.assertLen(output_schema.feature, 4) + + # With create_schema_features = False, only features on the original schema + # propagate to the new schema. + output_schema = expr.get_schema(create_schema_features=False) + self.assertLen(output_schema.feature, 3) + + def test_get_schema_empty_schema(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + schema = schema_pb2.Schema() + # By default, the output schema has all features present in the proto. + expr = expr.apply_schema(schema) + output_schema = expr.get_schema() + self.assertEmpty(schema.feature) + self.assertLen(output_schema.feature, 4) + + # With create_schema_features = False, the schema will be empty.
+ output_schema = expr.get_schema(create_schema_features=False) + self.assertEmpty(output_schema.feature) + self.assertEqual(schema_pb2.Schema(), output_schema) @test_util.run_all_in_graph_and_eager_modes class ExpressionValuesTest(tf.test.TestCase): - - def test_map_field_values_test(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_simple_prensor()) - - new_root = expr.map_field_values("foo", lambda x: x * 2, tf.int64, - "foo_doubled") - - leaf_node = expression_test_util.calculate_value_slowly( - new_root.get_descendant_or_error(path.Path(["foo_doubled"]))) - - self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) - self.assertAllEqual(leaf_node.values, [18, 16, 14]) - - def test_create_size_field(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - new_root = expr.create_size_field("doc.bar", "result") - new_field = new_root.get_descendant_or_error(path.Path(["doc", "result"])) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) - self.assertAllEqual(leaf_node.values, [1, 2, 1]) - - def test_create_has_field(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()) - new_root = expr.create_has_field("doc.keep_me", "result") - new_field = new_root.get_descendant_or_error(path.Path(["doc", "result"])) - leaf_node = expression_test_util.calculate_value_slowly(new_field) - self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) - self.assertAllEqual(leaf_node.values, [True, True, False]) - - def test_reroot_and_create_proto_index(self): - expr = create_expression.create_expression_from_prensor( - prensor_test_util.create_big_prensor()).reroot( - "doc").create_proto_index("proto_index") - proto_index = expr.get_child("proto_index") - new_field = expr.get_child("bar") - leaf_node = expression_test_util.calculate_value_slowly(new_field) - proto_index_node = expression_test_util.calculate_value_slowly(proto_index) - - self.assertIsNotNone(new_field) - self.assertTrue(new_field.is_repeated) - self.assertEqual(new_field.type, tf.string) - self.assertTrue(new_field.is_leaf) - self.assertEqual(new_field.known_field_names(), frozenset()) - self.assertEqual(leaf_node.values.dtype, tf.string) - - self.assertIsNotNone(proto_index) - self.assertFalse(proto_index.is_repeated) - self.assertEqual(proto_index.type, tf.int64) - self.assertTrue(proto_index.is_leaf) - self.assertEqual(proto_index.known_field_names(), frozenset()) - - self.assertEqual(proto_index_node.values.dtype, tf.int64) - - self.assertAllEqual([b"a", b"b", b"c", b"d"], leaf_node.values) - self.assertAllEqual([0, 1, 1, 2], leaf_node.parent_index) - self.assertAllEqual([0, 1, 1], proto_index_node.values) - self.assertAllEqual([0, 1, 2], proto_index_node.parent_index) + def test_map_field_values_test(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_simple_prensor() + ) + + new_root = expr.map_field_values( + "foo", lambda x: x * 2, tf.int64, "foo_doubled" + ) + + leaf_node = expression_test_util.calculate_value_slowly( + new_root.get_descendant_or_error(path.Path(["foo_doubled"])) + ) + + self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) + self.assertAllEqual(leaf_node.values, [18, 16, 14]) + + def test_create_size_field(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + new_root = 
expr.create_size_field("doc.bar", "result") + new_field = new_root.get_descendant_or_error(path.Path(["doc", "result"])) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) + self.assertAllEqual(leaf_node.values, [1, 2, 1]) + + def test_create_has_field(self): + expr = create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + new_root = expr.create_has_field("doc.keep_me", "result") + new_field = new_root.get_descendant_or_error(path.Path(["doc", "result"])) + leaf_node = expression_test_util.calculate_value_slowly(new_field) + self.assertAllEqual(leaf_node.parent_index, [0, 1, 2]) + self.assertAllEqual(leaf_node.values, [True, True, False]) + + def test_reroot_and_create_proto_index(self): + expr = ( + create_expression.create_expression_from_prensor( + prensor_test_util.create_big_prensor() + ) + .reroot("doc") + .create_proto_index("proto_index") + ) + proto_index = expr.get_child("proto_index") + new_field = expr.get_child("bar") + leaf_node = expression_test_util.calculate_value_slowly(new_field) + proto_index_node = expression_test_util.calculate_value_slowly(proto_index) + + self.assertIsNotNone(new_field) + self.assertTrue(new_field.is_repeated) + self.assertEqual(new_field.type, tf.string) + self.assertTrue(new_field.is_leaf) + self.assertEqual(new_field.known_field_names(), frozenset()) + self.assertEqual(leaf_node.values.dtype, tf.string) + + self.assertIsNotNone(proto_index) + self.assertFalse(proto_index.is_repeated) + self.assertEqual(proto_index.type, tf.int64) + self.assertTrue(proto_index.is_leaf) + self.assertEqual(proto_index.known_field_names(), frozenset()) + + self.assertEqual(proto_index_node.values.dtype, tf.int64) + + self.assertAllEqual([b"a", b"b", b"c", b"d"], leaf_node.values) + self.assertAllEqual([0, 1, 1, 2], leaf_node.parent_index) + self.assertAllEqual([0, 1, 1], proto_index_node.values) + self.assertAllEqual([0, 1, 2], proto_index_node.parent_index) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/ops/file_descriptor_set.py b/struct2tensor/ops/file_descriptor_set.py index db3f4b4..b3af30b 100644 --- a/struct2tensor/ops/file_descriptor_set.py +++ b/struct2tensor/ops/file_descriptor_set.py @@ -20,6 +20,7 @@ upon later ones (fails on circular dependencies but that won't happen). 4. It serializes the file descriptors into a FileDescriptorSet. """ + from typing import Sequence, Set from google.protobuf import descriptor, descriptor_pb2 @@ -27,111 +28,120 @@ from struct2tensor import path -def _are_dependencies_handled(file_descriptor: descriptor.FileDescriptor, - dependencies: Set[descriptor.Descriptor]) -> bool: - """Returns True iff dependencies of descriptor are in dependencies.""" - return all(dependency in dependencies for dependency in file_descriptor.dependencies) - - -def _order_dependencies(file_descriptors: Set[descriptor.FileDescriptor] - ) -> Sequence[descriptor.FileDescriptor]: - """Given a set of file descriptors, return them as an ordered list. - - Each file descriptor in the resulting list follows its dependencies. - - Args: - file_descriptors: a list of file descriptors. - - Returns: - file descriptors as an ordered list. - - Raises: - ValueError: if there is a circular dependency. - """ - # TODO(martinz): if you first create a graph, then you can calculate this in - # linear (instead of quadratic) time. However, let's hold off on that - # complexity for now. 
- result = [] - result_set = set() - progress = True - # We sort the file descriptors so that the list of descriptors returned - # is deterministic and the attr input to the DecodeProto* ops are stable. - remaining_file_descriptors = sorted(list(file_descriptors), - key=lambda x: x.name) - while remaining_file_descriptors and progress: - failed = [] - progress = False - while remaining_file_descriptors: - file_descriptor = remaining_file_descriptors.pop() - if _are_dependencies_handled(file_descriptor, result_set): - result.append(file_descriptor) - result_set.add(file_descriptor) - progress = True - else: - failed.append(file_descriptor) - remaining_file_descriptors = failed - if remaining_file_descriptors: - # In theory, the module should not have even compiled. - raise ValueError('Circular dependency ' + str( - [file_desc.name for file_desc in remaining_file_descriptors])) - return result +def _are_dependencies_handled( + file_descriptor: descriptor.FileDescriptor, dependencies: Set[descriptor.Descriptor] +) -> bool: + """Returns True iff dependencies of descriptor are in dependencies.""" + return all( + dependency in dependencies for dependency in file_descriptor.dependencies + ) + + +def _order_dependencies( + file_descriptors: Set[descriptor.FileDescriptor], +) -> Sequence[descriptor.FileDescriptor]: + """Given a set of file descriptors, return them as an ordered list. + + Each file descriptor in the resulting list follows its dependencies. + + Args: + file_descriptors: a list of file descriptors. + + Returns: + file descriptors as an ordered list. + + Raises: + ValueError: if there is a circular dependency. + """ + # TODO(martinz): if you first create a graph, then you can calculate this in + # linear (instead of quadratic) time. However, let's hold off on that + # complexity for now. + result = [] + result_set = set() + progress = True + # We sort the file descriptors so that the list of descriptors returned + # is deterministic and the attr input to the DecodeProto* ops are stable. + remaining_file_descriptors = sorted(list(file_descriptors), key=lambda x: x.name) + while remaining_file_descriptors and progress: + failed = [] + progress = False + while remaining_file_descriptors: + file_descriptor = remaining_file_descriptors.pop() + if _are_dependencies_handled(file_descriptor, result_set): + result.append(file_descriptor) + result_set.add(file_descriptor) + progress = True + else: + failed.append(file_descriptor) + remaining_file_descriptors = failed + if remaining_file_descriptors: + # In theory, the module should not have even compiled. 
+ raise ValueError( + "Circular dependency " + + str([file_desc.name for file_desc in remaining_file_descriptors]) + ) + return result def _get_dependencies_recursively( - initial_file_descriptors: Set[descriptor.FileDescriptor] + initial_file_descriptors: Set[descriptor.FileDescriptor], ) -> Set[descriptor.FileDescriptor]: - """Gets all dependencies of file_descriptors as a set (including itself).""" - file_descriptor_set = set(initial_file_descriptors) - boundary = list(initial_file_descriptors) - while boundary: - current = boundary.pop() - for dependency in current.dependencies: - if dependency not in file_descriptor_set: - file_descriptor_set.add(dependency) - boundary.append(dependency) - return file_descriptor_set + """Gets all dependencies of file_descriptors as a set (including itself).""" + file_descriptor_set = set(initial_file_descriptors) + boundary = list(initial_file_descriptors) + while boundary: + current = boundary.pop() + for dependency in current.dependencies: + if dependency not in file_descriptor_set: + file_descriptor_set.add(dependency) + boundary.append(dependency) + return file_descriptor_set def _create_file_descriptor_set_proto( - file_descriptor_list: Sequence[descriptor.FileDescriptor] + file_descriptor_list: Sequence[descriptor.FileDescriptor], ) -> descriptor_pb2.FileDescriptorSet: - """Creates a FileDescriptorSet proto from a list of file descriptors.""" - result = descriptor_pb2.FileDescriptorSet() - for file_descriptor in file_descriptor_list: - file_descriptor.CopyToProto(result.file.add()) - return result + """Creates a FileDescriptorSet proto from a list of file descriptors.""" + result = descriptor_pb2.FileDescriptorSet() + for file_descriptor in file_descriptor_list: + file_descriptor.CopyToProto(result.file.add()) + return result def _get_initial_file_descriptor_set( - descriptor_type: descriptor.Descriptor, - field_names: Sequence[str]) -> Set[descriptor.FileDescriptor]: - """Gets a set of file descriptors for a descriptor and extensions.""" - result = set() - result.add(descriptor_type.file) - for field_name in field_names: - if path.is_extension(field_name): - extension_field = descriptor_type.file.pool.FindExtensionByName( - path.get_raw_extension_name(field_name)) - extension_file = extension_field.file - if extension_file not in result: - result.add(extension_file) - return result + descriptor_type: descriptor.Descriptor, field_names: Sequence[str] +) -> Set[descriptor.FileDescriptor]: + """Gets a set of file descriptors for a descriptor and extensions.""" + result = set() + result.add(descriptor_type.file) + for field_name in field_names: + if path.is_extension(field_name): + extension_field = descriptor_type.file.pool.FindExtensionByName( + path.get_raw_extension_name(field_name) + ) + extension_file = extension_field.file + if extension_file not in result: + result.add(extension_file) + return result def get_file_descriptor_set_proto( - descriptor_type: descriptor.Descriptor, - field_names: Sequence[str]) -> descriptor_pb2.FileDescriptorSet: - """Returns a FileDescriptorSet for parsing field_names in a descriptor_type. - - The FileDescriptorSet has file descriptors for the file of the - descriptor_type, the files of any extensions in field_names, and any files - they depend upon. - - Args: - descriptor_type: the type to be parsed. - field_names: fields to be parsed. 
- """ - return _create_file_descriptor_set_proto( - _order_dependencies( - _get_dependencies_recursively( - _get_initial_file_descriptor_set(descriptor_type, field_names)))) + descriptor_type: descriptor.Descriptor, field_names: Sequence[str] +) -> descriptor_pb2.FileDescriptorSet: + """Returns a FileDescriptorSet for parsing field_names in a descriptor_type. + + The FileDescriptorSet has file descriptors for the file of the + descriptor_type, the files of any extensions in field_names, and any files + they depend upon. + + Args: + descriptor_type: the type to be parsed. + field_names: fields to be parsed. + """ + return _create_file_descriptor_set_proto( + _order_dependencies( + _get_dependencies_recursively( + _get_initial_file_descriptor_set(descriptor_type, field_names) + ) + ) + ) diff --git a/struct2tensor/ops/file_descriptor_set_test.py b/struct2tensor/ops/file_descriptor_set_test.py index 6403d3a..d694127 100644 --- a/struct2tensor/ops/file_descriptor_set_test.py +++ b/struct2tensor/ops/file_descriptor_set_test.py @@ -22,24 +22,25 @@ # it must be linked in and imported so that the extension can be found in the # pool. from struct2tensor.test import ( - test_map_pb2, + test_map_pb2, ) def _get_base_directory(): - return "" # pylint: disable=unreachable + return "" # pylint: disable=unreachable class FileDescriptorSetTest(absltest.TestCase): - - def test_get_file_descriptor_set_proto_simple_test_map(self): - file_set_proto = file_descriptor_set.get_file_descriptor_set_proto( - test_map_pb2.SubMessage.DESCRIPTOR, []) - self.assertLen(file_set_proto.file, 1) - self.assertEqual( - file_set_proto.file[0].name, - _get_base_directory() + "struct2tensor/test/test_map.proto") + def test_get_file_descriptor_set_proto_simple_test_map(self): + file_set_proto = file_descriptor_set.get_file_descriptor_set_proto( + test_map_pb2.SubMessage.DESCRIPTOR, [] + ) + self.assertLen(file_set_proto.file, 1) + self.assertEqual( + file_set_proto.file[0].name, + _get_base_directory() + "struct2tensor/test/test_map.proto", + ) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/ops/gen_decode_proto_map_op.py b/struct2tensor/ops/gen_decode_proto_map_op.py index 1559c62..67063af 100644 --- a/struct2tensor/ops/gen_decode_proto_map_op.py +++ b/struct2tensor/ops/gen_decode_proto_map_op.py @@ -17,7 +17,8 @@ from tensorflow.python.platform import resource_loader decode_proto_map_module = load_library.load_op_library( - resource_loader.get_path_to_datafile('_decode_proto_map_op.so')) + resource_loader.get_path_to_datafile("_decode_proto_map_op.so") +) decode_proto_map = decode_proto_map_module.decode_proto_map decode_proto_map_v2 = decode_proto_map_module.decode_proto_map_v2 diff --git a/struct2tensor/ops/gen_decode_proto_sparse.py b/struct2tensor/ops/gen_decode_proto_sparse.py index 35a6f15..86a7d1b 100644 --- a/struct2tensor/ops/gen_decode_proto_sparse.py +++ b/struct2tensor/ops/gen_decode_proto_sparse.py @@ -17,7 +17,8 @@ from tensorflow.python.platform import resource_loader decode_proto_sparse_module = load_library.load_op_library( - resource_loader.get_path_to_datafile('_decode_proto_sparse_op.so')) + resource_loader.get_path_to_datafile("_decode_proto_sparse_op.so") +) decode_proto_sparse_v2 = decode_proto_sparse_module.decode_proto_sparse_v2 decode_proto_sparse_v3 = decode_proto_sparse_module.decode_proto_sparse_v3 diff --git a/struct2tensor/ops/gen_equi_join_any_indices.py b/struct2tensor/ops/gen_equi_join_any_indices.py index 5af0dcf..e733932 100644 --- 
a/struct2tensor/ops/gen_equi_join_any_indices.py +++ b/struct2tensor/ops/gen_equi_join_any_indices.py @@ -17,5 +17,6 @@ from tensorflow.python.platform import resource_loader equi_join_any_indices_module = load_library.load_op_library( - resource_loader.get_path_to_datafile('_equi_join_any_indices_op.so')) + resource_loader.get_path_to_datafile("_equi_join_any_indices_op.so") +) equi_join_any_indices = equi_join_any_indices_module.equi_join_any_indices diff --git a/struct2tensor/ops/gen_equi_join_indices.py b/struct2tensor/ops/gen_equi_join_indices.py index 7ea9439..8f4678b 100644 --- a/struct2tensor/ops/gen_equi_join_indices.py +++ b/struct2tensor/ops/gen_equi_join_indices.py @@ -17,5 +17,6 @@ from tensorflow.python.platform import resource_loader equi_join_indices_module = load_library.load_op_library( - resource_loader.get_path_to_datafile('_equi_join_indices_op.so')) + resource_loader.get_path_to_datafile("_equi_join_indices_op.so") +) equi_join_indices = equi_join_indices_module.equi_join_indices diff --git a/struct2tensor/ops/gen_parquet_dataset.py b/struct2tensor/ops/gen_parquet_dataset.py index 92afe49..1dfb823 100644 --- a/struct2tensor/ops/gen_parquet_dataset.py +++ b/struct2tensor/ops/gen_parquet_dataset.py @@ -17,6 +17,7 @@ from tensorflow.python.platform import resource_loader parquet_dataset_module = load_library.load_op_library( - resource_loader.get_path_to_datafile('_parquet_dataset_op.so')) + resource_loader.get_path_to_datafile("_parquet_dataset_op.so") +) parquet_dataset = parquet_dataset_module.parquet_dataset diff --git a/struct2tensor/ops/gen_run_length_before.py b/struct2tensor/ops/gen_run_length_before.py index 6c6d158..ee13b19 100644 --- a/struct2tensor/ops/gen_run_length_before.py +++ b/struct2tensor/ops/gen_run_length_before.py @@ -17,5 +17,6 @@ from tensorflow.python.platform import resource_loader run_length_before_module = load_library.load_op_library( - resource_loader.get_path_to_datafile('_run_length_before_op.so')) + resource_loader.get_path_to_datafile("_run_length_before_op.so") +) run_length_before = run_length_before_module.run_length_before diff --git a/struct2tensor/ops/struct2tensor_ops.py b/struct2tensor/ops/struct2tensor_ops.py index af6ac78..ba6dacd 100644 --- a/struct2tensor/ops/struct2tensor_ops.py +++ b/struct2tensor/ops/struct2tensor_ops.py @@ -13,7 +13,6 @@ # limitations under the License. 
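# Aside — the gen_*.py wrappers above all follow one loading pattern; a
# minimal sketch with a hypothetical shared object and op name ("_my_op.so"
# and "my_op" are assumptions, not files in this repo):
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader

_my_op_module = load_library.load_op_library(
    resource_loader.get_path_to_datafile("_my_op.so")
)
my_op = _my_op_module.my_op  # the loaded op becomes a regular Python callable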
"""Utilities for manipulating prensors.""" - from typing import NamedTuple, Optional, Sequence, Tuple import tensorflow as tf @@ -21,30 +20,30 @@ from struct2tensor import path from struct2tensor.ops import ( - file_descriptor_set, - gen_decode_proto_map_op, - gen_decode_proto_sparse, - gen_equi_join_any_indices, - gen_equi_join_indices, - gen_run_length_before, + file_descriptor_set, + gen_decode_proto_map_op, + gen_decode_proto_sparse, + gen_equi_join_any_indices, + gen_equi_join_indices, + gen_run_length_before, ) def _get_dtype_from_cpp_type(cpp_type: int) -> tf.DType: - """Converts a cpp type in FieldDescriptor to the appropriate dtype.""" - library = { - descriptor.FieldDescriptor.CPPTYPE_INT32: tf.int32, - descriptor.FieldDescriptor.CPPTYPE_INT64: tf.int64, - descriptor.FieldDescriptor.CPPTYPE_UINT32: tf.uint32, - descriptor.FieldDescriptor.CPPTYPE_UINT64: tf.uint64, - descriptor.FieldDescriptor.CPPTYPE_DOUBLE: tf.float64, - descriptor.FieldDescriptor.CPPTYPE_FLOAT: tf.float32, - descriptor.FieldDescriptor.CPPTYPE_BOOL: tf.bool, - descriptor.FieldDescriptor.CPPTYPE_ENUM: tf.int32, - descriptor.FieldDescriptor.CPPTYPE_STRING: tf.string, - descriptor.FieldDescriptor.CPPTYPE_MESSAGE: tf.string - } - return library[cpp_type] + """Converts a cpp type in FieldDescriptor to the appropriate dtype.""" + library = { + descriptor.FieldDescriptor.CPPTYPE_INT32: tf.int32, + descriptor.FieldDescriptor.CPPTYPE_INT64: tf.int64, + descriptor.FieldDescriptor.CPPTYPE_UINT32: tf.uint32, + descriptor.FieldDescriptor.CPPTYPE_UINT64: tf.uint64, + descriptor.FieldDescriptor.CPPTYPE_DOUBLE: tf.float64, + descriptor.FieldDescriptor.CPPTYPE_FLOAT: tf.float32, + descriptor.FieldDescriptor.CPPTYPE_BOOL: tf.bool, + descriptor.FieldDescriptor.CPPTYPE_ENUM: tf.int32, + descriptor.FieldDescriptor.CPPTYPE_STRING: tf.string, + descriptor.FieldDescriptor.CPPTYPE_MESSAGE: tf.string, + } + return library[cpp_type] # A named tuple for parse_full_message_level and parse_message_level @@ -55,10 +54,10 @@ def _get_dtype_from_cpp_type(cpp_type: int) -> tf.DType: # value: tf.Tensor # index: tf.Tensor class _ParsedField(NamedTuple): - field_name: str - field_descriptor: Optional[descriptor.FieldDescriptor] - value: tf.Tensor - index: tf.Tensor + field_name: str + field_descriptor: Optional[descriptor.FieldDescriptor] + value: tf.Tensor + index: tf.Tensor def parse_full_message_level( @@ -66,53 +65,55 @@ def parse_full_message_level( descriptor_type: descriptor.Descriptor, message_format: str = "binary", backing_str_tensor: Optional[tf.Tensor] = None, - honor_proto3_optional_semantics: bool = False) -> Sequence[_ParsedField]: - """Parses all of the fields at a level of a message. - - If there is a field with a message type, it is parsed as a string. Then, the - function can be applied recursively. - Note: this will not extract extensions. - - Args: - tensor_of_protos: a 1-D tensor of strings of protocol buffers. - descriptor_type: a descriptor for the protocol buffer to parse. See - https://github.com/protocolbuffers/protobuf/blob/master/python/google/protobuf/descriptor.py - message_format: Indicates the format of the protocol buffer: is one of - 'text' or 'binary'. - backing_str_tensor: a string tensor representing the root serialized proto. - This is passed to keep string_views of the tensor valid for all children - of the root expression - honor_proto3_optional_semantics: if True, and if a proto3 primitive optional - field without the presence semantic (i.e. 
the field is without the - "optional" or "repeated" label) is requested to be parsed, it will always - have a value for each input parent message. If a value is not present on - wire, the default value (0 or "") will be used. - - Returns: - list of named tuples, one per field_name in field_names: - field_name: the string from field_names. - field_descriptor: descriptor_type.fields_by_name[field_name] - value: a 1-D tensor of the values from the field field_name. - index: an index, such that for all i, tensor_of_protos[index[i]] has a - value value[i]. Note that sometimes index[i]=index[i+1], implying a - repeated field field_name. - """ - field_names = [field.name for field in descriptor_type.fields] - return parse_message_level( - tensor_of_protos, - descriptor_type, - field_names, - message_format, - backing_str_tensor=backing_str_tensor, - honor_proto3_optional_semantics=honor_proto3_optional_semantics) - - -def _get_field_descriptor(descriptor_type: descriptor.Descriptor, - field_name: str): - if path.is_extension(field_name): - return descriptor_type.file.pool.FindExtensionByName( - path.get_raw_extension_name(field_name)) - return descriptor_type.fields_by_name[field_name] + honor_proto3_optional_semantics: bool = False, +) -> Sequence[_ParsedField]: + """Parses all of the fields at a level of a message. + + If there is a field with a message type, it is parsed as a string. Then, the + function can be applied recursively. + Note: this will not extract extensions. + + Args: + tensor_of_protos: a 1-D tensor of strings of protocol buffers. + descriptor_type: a descriptor for the protocol buffer to parse. See + https://github.com/protocolbuffers/protobuf/blob/master/python/google/protobuf/descriptor.py + message_format: Indicates the format of the protocol buffer: is one of + 'text' or 'binary'. + backing_str_tensor: a string tensor representing the root serialized proto. + This is passed to keep string_views of the tensor valid for all children + of the root expression + honor_proto3_optional_semantics: if True, and if a proto3 primitive optional + field without the presence semantic (i.e. the field is without the + "optional" or "repeated" label) is requested to be parsed, it will always + have a value for each input parent message. If a value is not present on + wire, the default value (0 or "") will be used. + + Returns: + list of named tuples, one per field_name in field_names: + field_name: the string from field_names. + field_descriptor: descriptor_type.fields_by_name[field_name] + value: a 1-D tensor of the values from the field field_name. + index: an index, such that for all i, tensor_of_protos[index[i]] has a + value value[i]. Note that sometimes index[i]=index[i+1], implying a + repeated field field_name. 
+ """ + field_names = [field.name for field in descriptor_type.fields] + return parse_message_level( + tensor_of_protos, + descriptor_type, + field_names, + message_format, + backing_str_tensor=backing_str_tensor, + honor_proto3_optional_semantics=honor_proto3_optional_semantics, + ) + + +def _get_field_descriptor(descriptor_type: descriptor.Descriptor, field_name: str): + if path.is_extension(field_name): + return descriptor_type.file.pool.FindExtensionByName( + path.get_raw_extension_name(field_name) + ) + return descriptor_type.fields_by_name[field_name] def parse_message_level( @@ -121,157 +122,161 @@ def parse_message_level( field_names: Sequence[str], message_format: str = "binary", backing_str_tensor: Optional[tf.Tensor] = None, - honor_proto3_optional_semantics: bool = False) -> Sequence[_ParsedField]: - """Parses a subset of the fields at a level of a message. - - If there is a field with a message type, it is parsed as a string. Then, the - function can be applied recursively. - - Args: - tensor_of_protos: a 1-D tensor of strings of protocol buffers. - descriptor_type: a descriptor for the protocol buffer to parse. See - https://github.com/protocolbuffers/protobuf/blob/master/python/google/protobuf/descriptor.py - field_names: the names of the fields to parse. - message_format: Indicates the format of the protocol buffer: is one of - 'text' or 'binary'. - backing_str_tensor: a possible string tensor backing the string_view for - intermediate serialized protos. - honor_proto3_optional_semantics: if True, and if a proto3 primitive optional - field without the presence semantic (i.e. the field is without the - "optional" or "repeated" label) is requested to be parsed, it will always - have a value for each input parent message. If a value is not present on - wire, the default value (0 or "") will be used. - - Returns: - list of named _ParsedField, one per field_name in field_names: - field_name: the string from field_names. - field_descriptor: descriptor_type.fields_by_name[field_name] - value: a 1-D tensor of the values from the field field_name. - index: an index, such that for all i, tensor_of_protos[index[i]] has a - value value[i]. Note that sometimes index[i]=index[i+1], implying a - repeated field field_name. - - """ - if not field_names: - return [] - # We sort the field names so that the input attr to DecodeProtoSparseV2 op - # is deterministic. - field_names = sorted(field_names) - message_type = descriptor_type.full_name - descriptor_set = file_descriptor_set.get_file_descriptor_set_proto( - descriptor_type, field_names) - descriptor_literal = descriptor_set.SerializeToString() - # TODO(martinz): catch KeyError and give a better error. 
- field_descriptors = [ - _get_field_descriptor(descriptor_type, field_name) - for field_name in field_names - ] - output_types = [ - _get_dtype_from_cpp_type(field_descriptor.cpp_type) - for field_descriptor in field_descriptors - ] - if tf.is_tensor(backing_str_tensor): - assert message_format == "binary", ( - "message_format must be 'binary' if a backing_str_tensor is provided") - backing_str_tensor = [backing_str_tensor] - else: - backing_str_tensor = [] - values, indices = gen_decode_proto_sparse.decode_proto_sparse_v4( - tensor_of_protos, - backing_str_tensor, - descriptor_literal=descriptor_literal, - message_type=message_type, - num_fields=len(field_names), - field_names=list(field_names), - output_types=output_types, - message_format=message_format, - honor_proto3_optional_semantics=honor_proto3_optional_semantics) - - result = [] - for field_name, field_descriptor, value, index in zip(field_names, - field_descriptors, - values, indices): - result.append( - _ParsedField( - field_name=field_name, - field_descriptor=field_descriptor, - value=value, - index=index)) - - return result + honor_proto3_optional_semantics: bool = False, +) -> Sequence[_ParsedField]: + """Parses a subset of the fields at a level of a message. + + If there is a field with a message type, it is parsed as a string. Then, the + function can be applied recursively. + + Args: + tensor_of_protos: a 1-D tensor of strings of protocol buffers. + descriptor_type: a descriptor for the protocol buffer to parse. See + https://github.com/protocolbuffers/protobuf/blob/master/python/google/protobuf/descriptor.py + field_names: the names of the fields to parse. + message_format: Indicates the format of the protocol buffer: is one of + 'text' or 'binary'. + backing_str_tensor: a possible string tensor backing the string_view for + intermediate serialized protos. + honor_proto3_optional_semantics: if True, and if a proto3 primitive optional + field without the presence semantic (i.e. the field is without the + "optional" or "repeated" label) is requested to be parsed, it will always + have a value for each input parent message. If a value is not present on + wire, the default value (0 or "") will be used. + + Returns: + list of named _ParsedField, one per field_name in field_names: + field_name: the string from field_names. + field_descriptor: descriptor_type.fields_by_name[field_name] + value: a 1-D tensor of the values from the field field_name. + index: an index, such that for all i, tensor_of_protos[index[i]] has a + value value[i]. Note that sometimes index[i]=index[i+1], implying a + repeated field field_name. + + """ + if not field_names: + return [] + # We sort the field names so that the input attr to DecodeProtoSparseV2 op + # is deterministic. + field_names = sorted(field_names) + message_type = descriptor_type.full_name + descriptor_set = file_descriptor_set.get_file_descriptor_set_proto( + descriptor_type, field_names + ) + descriptor_literal = descriptor_set.SerializeToString() + # TODO(martinz): catch KeyError and give a better error. 
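# A sketch of the TODO above (an assumption, not implemented by this patch):
#
#   try:
#       field_descriptors = [
#           _get_field_descriptor(descriptor_type, field_name)
#           for field_name in field_names
#       ]
#   except KeyError as exc:  # raised by fields_by_name / FindExtensionByName
#       raise ValueError(
#           f"Unknown field {exc} on {descriptor_type.full_name}"
#       ) from exc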
+ field_descriptors = [ + _get_field_descriptor(descriptor_type, field_name) for field_name in field_names + ] + output_types = [ + _get_dtype_from_cpp_type(field_descriptor.cpp_type) + for field_descriptor in field_descriptors + ] + if tf.is_tensor(backing_str_tensor): + assert message_format == "binary", ( + "message_format must be 'binary' if a backing_str_tensor is provided" + ) + backing_str_tensor = [backing_str_tensor] + else: + backing_str_tensor = [] + values, indices = gen_decode_proto_sparse.decode_proto_sparse_v4( + tensor_of_protos, + backing_str_tensor, + descriptor_literal=descriptor_literal, + message_type=message_type, + num_fields=len(field_names), + field_names=list(field_names), + output_types=output_types, + message_format=message_format, + honor_proto3_optional_semantics=honor_proto3_optional_semantics, + ) + + result = [] + for field_name, field_descriptor, value, index in zip( + field_names, field_descriptors, values, indices + ): + result.append( + _ParsedField( + field_name=field_name, + field_descriptor=field_descriptor, + value=value, + index=index, + ) + ) + + return result def run_length_before(a: tf.Tensor) -> tf.Tensor: - r"""Returns the run length of each set of elements in a vector. + r"""Returns the run length of each set of elements in a vector. - Args: - a: a 1D int64 tensor. This assumes that for all a_i, a_j, if i <= j, then - a_i <= a_j. + Args: + a: a 1D int64 tensor. This assumes that for all a_i, a_j, if i <= j, then + a_i <= a_j. - Returns: - 1D int64 tensor [b_0,...,b_n] where b_n := \sum_{i=0}^{n-1} I(a_i=a_n) - """ - return gen_run_length_before.run_length_before(a) + Returns: + 1D int64 tensor [b_0,...,b_n] where b_n := \sum_{i=0}^{n-1} I(a_i=a_n) + """ + return gen_run_length_before.run_length_before(a) -def create_sparse_tensor_for_repeated(parent_index: tf.Tensor, - values: tf.Tensor, dense_shape: tf.Tensor - ) -> tf.SparseTensor: - """Helps to get the sparse tensor for a repeated PrensorField. +def create_sparse_tensor_for_repeated( + parent_index: tf.Tensor, values: tf.Tensor, dense_shape: tf.Tensor +) -> tf.SparseTensor: + """Helps to get the sparse tensor for a repeated PrensorField. - Args: - parent_index: a 1D int64 tensor, which has a pointer to the parent message. - values: a 1D tensor, which has the primitive values associated with a field. - dense_shape: a 1D int64 tensor, representing the dense shape. + Args: + parent_index: a 1D int64 tensor, which has a pointer to the parent message. + values: a 1D tensor, which has the primitive values associated with a field. + dense_shape: a 1D int64 tensor, representing the dense shape. - Returns: - A Sparse Tensor representing a ragged array. - """ - run_length_before_tensor = run_length_before(parent_index) - indices = tf.stack([parent_index, run_length_before_tensor], axis=1) - return tf.SparseTensor( - indices=indices, values=values, dense_shape=dense_shape) + Returns: + A Sparse Tensor representing a ragged array. + """ + run_length_before_tensor = run_length_before(parent_index) + indices = tf.stack([parent_index, run_length_before_tensor], axis=1) + return tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape) def equi_join_indices(a: tf.Tensor, b: tf.Tensor) -> tf.Tensor: - """A custom op such that a broadcast operation can be done. + """A custom op such that a broadcast operation can be done. 
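# Worked example for the two helpers above (values are illustrative
# assumptions): run_length_before counts prior equal entries, so
#   run_length_before([0, 0, 1, 2, 2, 2]) == [0, 1, 0, 0, 1, 2]
# and for a repeated field with parent_index [0, 0, 1] and values [7, 8, 9],
# create_sparse_tensor_for_repeated stacks the parent index with those run
# lengths to get sparse indices [[0, 0], [0, 1], [1, 0]]: row = parent
# message, column = position within that parent.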
- Args: - a: a tensor that is an int64 vector, where for all i, a[i] <= a[i+1] - b: a tensor that is an int64 vector, where for all i, b[i] <= b[i+1] + Args: + a: a tensor that is an int64 vector, where for all i, a[i] <= a[i+1] + b: a tensor that is an int64 vector, where for all i, b[i] <= b[i+1] - Returns: - [index_a, index_b] where: - 1. For every k, a[index_a[k]] = b[index_b[k]] - 2. for every i,j, iff a[i]==b[j], then there exists a k where - index_a[k]=i and index_b[k]=j. - 3. Moreover, for any k, k' where k < k', - index_a[k] <= index_a[k'], and if index_a[k] == index_a[k'], then - index_b[k] <= index_b[k']. - """ - return gen_equi_join_indices.equi_join_indices(a, b) + Returns: + [index_a, index_b] where: + 1. For every k, a[index_a[k]] = b[index_b[k]] + 2. for every i,j, iff a[i]==b[j], then there exists a k where + index_a[k]=i and index_b[k]=j. + 3. Moreover, for any k, k' where k < k', + index_a[k] <= index_a[k'], and if index_a[k] == index_a[k'], then + index_b[k] <= index_b[k']. + """ + return gen_equi_join_indices.equi_join_indices(a, b) def equi_join_any_indices(a: tf.Tensor, b: tf.Tensor) -> tf.Tensor: - """A custom op such that a broadcast subtree operation can be done. + """A custom op such that a broadcast subtree operation can be done. - Similar to `equi_join_indices`, except this does not assume `a` and `b` are - monotonically increasing. Prefer to use equi_join_indices if possible. + Similar to `equi_join_indices`, except this does not assume `a` and `b` are + monotonically increasing. Prefer to use equi_join_indices if possible. - Args: - a: a tensor that is an int64 vector - b: a tensor that is an int64 vector + Args: + a: a tensor that is an int64 vector + b: a tensor that is an int64 vector - Returns: - [index_a, index_b] where: - 1. For every k, a[index_a[k]] = b[index_b[k]] - 2. for every i,j, iff a[i]==b[j], then there exists a k where - index_a[k]=i and index_b[k]=j. - 3. Moreover, for any k, k' where k < k', - index_a[k] <= index_a[k'], and if index_a[k] == index_a[k'], then - index_b[k] <= index_b[k']. - """ - return gen_equi_join_any_indices.equi_join_any_indices(a, b) + Returns: + [index_a, index_b] where: + 1. For every k, a[index_a[k]] = b[index_b[k]] + 2. for every i,j, iff a[i]==b[j], then there exists a k where + index_a[k]=i and index_b[k]=j. + 3. Moreover, for any k, k' where k < k', + index_a[k] <= index_a[k'], and if index_a[k] == index_a[k'], then + index_b[k] <= index_b[k']. + """ + return gen_equi_join_any_indices.equi_join_any_indices(a, b) def parse_proto_map( @@ -279,38 +284,46 @@ def parse_proto_map( map_entry_parent_indices, map_entry_descriptor: descriptor.Descriptor, keys_needed: Sequence[str], - backing_str_tensor: Optional[tf.Tensor] = None + backing_str_tensor: Optional[tf.Tensor] = None, ) -> Sequence[Tuple[tf.Tensor, tf.Tensor]]: - """A custom op to parse serialized Protobuf map entries. - - Args: - map_entries: a 1D string tensor that contains serialized map entry - sub-messages. - map_entry_parent_indices: a 1D int64 tensor of the same length as - map_entries. map_entry_parent_indices[i] == j means map_entries[i] belongs - to the j-th map. - map_entry_descriptor: the proto descriptor of the map entry sub-message. - keys_needed: keys that are needed to be looked up in the map. If the map's - keys are integers, then these strings will be parsed as integers in - decimal. If the map's keys are booleans, then only "0" and "1" are - expected. 
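# Worked example for the equi-join ops above (taken from the "broadcast"
# test case later in this patch):
#   equi_join_indices([0, 1, 1], [0, 1, 2]) -> ([0, 1, 2], [0, 1, 1])
# Both occurrences of 1 in `a` pair with the single 1 in `b`, which is
# exactly the duplication a broadcast needs.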
- backing_str_tensor: a possible string tensor backing the string_view for - intermediate serialized protos. - - Returns: - A list of tuples one for each key in `keys_needed`. In each tuple, the first - term contains decoded values; the second term contains the parent indices - for the values. - """ - keys_needed_as_list = list(keys_needed) - value_fd = map_entry_descriptor.fields_by_name["value"] - - backing_str_tensor = [backing_str_tensor] if tf.is_tensor(backing_str_tensor) else [] - - values, parent_indices = gen_decode_proto_map_op.decode_proto_map_v2( - map_entries, map_entry_parent_indices, backing_str_tensor, - map_entry_descriptor.full_name, keys_needed_as_list, - len(keys_needed_as_list), _get_dtype_from_cpp_type(value_fd.cpp_type), - file_descriptor_set.get_file_descriptor_set_proto( - map_entry_descriptor, ["key", "value"]).SerializeToString()) - return list(zip(values, parent_indices)) + """A custom op to parse serialized Protobuf map entries. + + Args: + map_entries: a 1D string tensor that contains serialized map entry + sub-messages. + map_entry_parent_indices: a 1D int64 tensor of the same length as + map_entries. map_entry_parent_indices[i] == j means map_entries[i] belongs + to the j-th map. + map_entry_descriptor: the proto descriptor of the map entry sub-message. + keys_needed: keys that are needed to be looked up in the map. If the map's + keys are integers, then these strings will be parsed as integers in + decimal. If the map's keys are booleans, then only "0" and "1" are + expected. + backing_str_tensor: a possible string tensor backing the string_view for + intermediate serialized protos. + + Returns: + A list of tuples one for each key in `keys_needed`. In each tuple, the first + term contains decoded values; the second term contains the parent indices + for the values. 
+ """ + keys_needed_as_list = list(keys_needed) + value_fd = map_entry_descriptor.fields_by_name["value"] + + backing_str_tensor = ( + [backing_str_tensor] if tf.is_tensor(backing_str_tensor) else [] + ) + + values, parent_indices = gen_decode_proto_map_op.decode_proto_map_v2( + map_entries, + map_entry_parent_indices, + backing_str_tensor, + map_entry_descriptor.full_name, + keys_needed_as_list, + len(keys_needed_as_list), + _get_dtype_from_cpp_type(value_fd.cpp_type), + file_descriptor_set.get_file_descriptor_set_proto( + map_entry_descriptor, ["key", "value"] + ).SerializeToString(), + ) + return list(zip(values, parent_indices)) diff --git a/struct2tensor/ops/struct2tensor_ops_test.py b/struct2tensor/ops/struct2tensor_ops_test.py index d96ad9b..289bcba 100644 --- a/struct2tensor/ops/struct2tensor_ops_test.py +++ b/struct2tensor/ops/struct2tensor_ops_test.py @@ -19,714 +19,785 @@ import tensorflow as tf from absl.testing import absltest, parameterized from tensorflow.python.framework import ( - test_util, # pylint: disable=g-direct-tensorflow-import + test_util, # pylint: disable=g-direct-tensorflow-import ) from struct2tensor.ops import struct2tensor_ops from struct2tensor.test import ( - test_extension_pb2, - test_map_pb2, - test_pb2, + test_extension_pb2, + test_map_pb2, + test_pb2, ) INDEX = "index" VALUE = "value" -_EQUIJOIN_TEST_CASES = [{ - "testcase_name": "simple", - "a": [0, 0, 1, 1, 2, 3, 4], - "b": [0, 0, 2, 2, 3], - "expected_index_a": [0, 0, 1, 1, 4, 4, 5], - "expected_index_b": [0, 1, 0, 1, 2, 3, 4] -}, { - "testcase_name": "simple_2", - "a": [0, 1, 1, 2], - "b": [0, 1, 2], - "expected_index_a": [0, 1, 2, 3], - "expected_index_b": [0, 1, 1, 2] -}, { - "testcase_name": "empty", - "a": [], - "b": [0, 1, 2], - "expected_index_a": [], - "expected_index_b": [] -}, { - "testcase_name": "empty_2", - "a": [0, 1, 1, 2], - "b": [], - "expected_index_a": [], - "expected_index_b": [] -}, { - "testcase_name": "both_empty", - "a": [], - "b": [], - "expected_index_a": [], - "expected_index_b": [] -}, { - "testcase_name": "no_overlap", - "a": [0, 1, 1, 2], - "b": [3, 4, 5], - "expected_index_a": [], - "expected_index_b": [] -}, { - "testcase_name": "broadcast", - "a": [0, 1, 1], - "b": [0, 1, 2], - "expected_index_a": [0, 1, 2], - "expected_index_b": [0, 1, 1] -}] +_EQUIJOIN_TEST_CASES = [ + { + "testcase_name": "simple", + "a": [0, 0, 1, 1, 2, 3, 4], + "b": [0, 0, 2, 2, 3], + "expected_index_a": [0, 0, 1, 1, 4, 4, 5], + "expected_index_b": [0, 1, 0, 1, 2, 3, 4], + }, + { + "testcase_name": "simple_2", + "a": [0, 1, 1, 2], + "b": [0, 1, 2], + "expected_index_a": [0, 1, 2, 3], + "expected_index_b": [0, 1, 1, 2], + }, + { + "testcase_name": "empty", + "a": [], + "b": [0, 1, 2], + "expected_index_a": [], + "expected_index_b": [], + }, + { + "testcase_name": "empty_2", + "a": [0, 1, 1, 2], + "b": [], + "expected_index_a": [], + "expected_index_b": [], + }, + { + "testcase_name": "both_empty", + "a": [], + "b": [], + "expected_index_a": [], + "expected_index_b": [], + }, + { + "testcase_name": "no_overlap", + "a": [0, 1, 1, 2], + "b": [3, 4, 5], + "expected_index_a": [], + "expected_index_b": [], + }, + { + "testcase_name": "broadcast", + "a": [0, 1, 1], + "b": [0, 1, 2], + "expected_index_a": [0, 1, 2], + "expected_index_b": [0, 1, 1], + }, +] def _parse_full_message_level_as_dict(proto_list): - serialized = [proto.SerializeToString() for proto in proto_list] - parsed_field_list = struct2tensor_ops.parse_full_message_level( - tf.constant(serialized), proto_list[0].DESCRIPTOR) - 
parsed_field_dict = {} - for parsed_field in parsed_field_list: - parsed_field_dict[parsed_field.field_name] = parsed_field - return parsed_field_dict + serialized = [proto.SerializeToString() for proto in proto_list] + parsed_field_list = struct2tensor_ops.parse_full_message_level( + tf.constant(serialized), proto_list[0].DESCRIPTOR + ) + parsed_field_dict = {} + for parsed_field in parsed_field_list: + parsed_field_dict[parsed_field.field_name] = parsed_field + return parsed_field_dict def _make_dict_runnable(level_as_dict): - """Prepares output of parse_full_message_level_as_dict for evaluate.""" - result = {} - for key, value in level_as_dict.items(): - local_dict = {} - local_dict[INDEX] = value.index - local_dict[VALUE] = value.value - result[key] = local_dict - return result + """Prepares output of parse_full_message_level_as_dict for evaluate.""" + result = {} + for key, value in level_as_dict.items(): + local_dict = {} + local_dict[INDEX] = value.index + local_dict[VALUE] = value.value + result[key] = local_dict + return result def _get_full_message_level_runnable(proto_list): - return _make_dict_runnable(_parse_full_message_level_as_dict(proto_list)) + return _make_dict_runnable(_parse_full_message_level_as_dict(proto_list)) # TODO(martinz): test empty tensors for decode_proto_sparse more thoroughly. @test_util.run_all_in_graph_and_eager_modes class PrensorOpsTest(parameterized.TestCase, tf.test.TestCase): - - def test_out_of_order_fields(self): - fragments = [ - test_pb2.Event(query_token=["aaa"]).SerializeToString(), - test_pb2.Event(query_token=["bbb"]).SerializeToString(), - test_pb2.Event(event_id="abc").SerializeToString(), - test_pb2.Event(action_mask=[False, True]).SerializeToString(), - ] - # Test against all 4! permutations of fragments, and for each permutation - # test parsing all possible combination of 4 fields. - for indices in itertools.permutations(range(len(fragments))): - proto = b"".join([fragments[i] for i in indices]) - - for i in indices: - if i == 0: - expected_query_tokens = [b"aaa", b"bbb"] - break - if i == 1: - expected_query_tokens = [b"bbb", b"aaa"] - break - - # "query" is not on wire at all. - all_fields_to_parse = ["query_token", "event_id", "action_mask", "query"] - expected_field_value = { - "action_mask": [False, True], - "query_token": expected_query_tokens, - "event_id": [b"abc"], - "query": np.array([], dtype=object), - } - - for num_fields_to_parse in range(len(all_fields_to_parse)): - for comb in itertools.combinations( - all_fields_to_parse, num_fields_to_parse): - parsed_fields = struct2tensor_ops.parse_message_level( - [proto], test_pb2.Event.DESCRIPTOR, comb) - self.assertLen(parsed_fields, len(comb)) - for f in parsed_fields: + def test_out_of_order_fields(self): + fragments = [ + test_pb2.Event(query_token=["aaa"]).SerializeToString(), + test_pb2.Event(query_token=["bbb"]).SerializeToString(), + test_pb2.Event(event_id="abc").SerializeToString(), + test_pb2.Event(action_mask=[False, True]).SerializeToString(), + ] + # Test against all 4! permutations of fragments, and for each permutation + # test parsing all possible combination of 4 fields. + for indices in itertools.permutations(range(len(fragments))): + proto = b"".join([fragments[i] for i in indices]) + + for i in indices: + if i == 0: + expected_query_tokens = [b"aaa", b"bbb"] + break + if i == 1: + expected_query_tokens = [b"bbb", b"aaa"] + break + + # "query" is not on wire at all. 
+ all_fields_to_parse = ["query_token", "event_id", "action_mask", "query"] + expected_field_value = { + "action_mask": [False, True], + "query_token": expected_query_tokens, + "event_id": [b"abc"], + "query": np.array([], dtype=object), + } + + for num_fields_to_parse in range(len(all_fields_to_parse)): + for comb in itertools.combinations( + all_fields_to_parse, num_fields_to_parse + ): + parsed_fields = struct2tensor_ops.parse_message_level( + [proto], test_pb2.Event.DESCRIPTOR, comb + ) + self.assertLen(parsed_fields, len(comb)) + for f in parsed_fields: + self.assertAllEqual( + expected_field_value[f.field_name], + f.value, + f"field: {f.field_name}, permutation: {indices}, field_to_parse: {comb}", + ) + + def test_out_of_order_repeated_fields_1(self): + # This is a 2-1-2 wire number pattern. + proto = ( + test_pb2.Event(query_token=["aaa"]).SerializeToString() + + test_pb2.Event(event_id="abc").SerializeToString() + + test_pb2.Event(query_token=["bbb"]).SerializeToString() + ) + expected_field_value = {"query_token": [b"aaa", b"bbb"], "event_id": [b"abc"]} + + for fields_to_parse in [ + ["query_token"], + ["event_id"], + ["query_token", "event_id"], + ]: + parsed_fields = struct2tensor_ops.parse_message_level( + [proto], test_pb2.Event.DESCRIPTOR, fields_to_parse + ) + for f in parsed_fields: + self.assertAllEqual(expected_field_value[f.field_name], f.value) + + def test_out_of_order_repeated_fields_2(self): + # This follows a 3-5-3 wire number pattern, where fields 3 and 4 are parsed. + proto = ( + test_pb2.Event(query_token=["aaa"]).SerializeToString() + + test_pb2.Event(action_mask=[True]).SerializeToString() + + test_pb2.Event(query_token=["bbb"]).SerializeToString() + ) + expected_field_value = { + "query_token": [b"aaa", b"bbb"], + "action_mask": [True], + "action": [], + } + for fields_to_parse in [ + ["query_token"], + ["action_mask"], + ["query_token", "action_mask"], + ["query_token", "action"], + ["query_token", "action_mask", "action"], + ]: + parsed_fields = struct2tensor_ops.parse_message_level( + [proto], test_pb2.Event.DESCRIPTOR, fields_to_parse + ) + for f in parsed_fields: + expected_value = expected_field_value[f.field_name] + if expected_value: + self.assertAllEqual(expected_value, f.value) + + def test_out_of_order_repeated_fields_3(self): + # This follows a 3-5-3 wire number pattern, where fields 3 and 4 are parsed.
+ proto = ( + test_pb2.AllSimple(repeated_string=["aaa"]).SerializeToString() + + test_pb2.AllSimple(repeated_int64=[12345]).SerializeToString() + + test_pb2.AllSimple(repeated_string=["bbb"]).SerializeToString() + ) + expected_field_value = { + "repeated_string": [b"aaa", b"bbb"], + "repeated_int64": [12345], + "repeated_int32": [], + "repeated_uint32": [], + } + + for fields_to_parse in [ + ["repeated_int64"], + ["repeated_string"], + ["repeated_string", "repeated_uint32", "repeated_int32"], + ]: + parsed_fields = struct2tensor_ops.parse_message_level( + [proto], test_pb2.AllSimple.DESCRIPTOR, fields_to_parse + ) + for f in parsed_fields: + self.assertAllEqual(expected_field_value[f.field_name], f.value) + + def test_parse_full_message_level_for_event(self): + event = test_pb2.Event() + event.event_id = "foo" + event.query = "query" + event.query_token.append("a") + event.query_token.append("b") + action0 = event.action.add() + action0.doc_id = "abc" + action1 = event.action.add() + event.user_info.age_in_years = 38 + event2 = test_pb2.Event() + action2 = event2.action.add() + action2.doc_id = "def" + parsed_field_dict = _parse_full_message_level_as_dict([event, event2]) + doc_id = parsed_field_dict["action"] + serialized_actions = [ + proto.SerializeToString() for proto in [action0, action1, action2] + ] + + self.assertAllEqual(doc_id.index, [0, 0, 1]) + self.assertAllEqual(doc_id.value, serialized_actions) + + def test_parse_full_message_level_for_simple_action_multiple(self): + """Test multiple messages.""" + as1 = test_pb2.AllSimple() + as1.optional_string = "a" + as1.repeated_string.append("b") + as1.repeated_string.append("c") + as2 = test_pb2.AllSimple() + as2.optional_string = "d" + as2.optional_int32 = 123 + as3 = test_pb2.AllSimple() + as3.repeated_string.append("d") + as3.repeated_string.append("e") + as3.optional_int32 = 123 + + parsed_field_dict = _parse_full_message_level_as_dict([as1, as2, as3]) + + doc_id = parsed_field_dict["repeated_string"] + self.assertAllEqual(doc_id.index, [0, 0, 2, 2]) + self.assertAllEqual(doc_id.value, [b"b", b"c", b"d", b"e"]) + + def test_parse_full_message_level_for_all_simple_repeated_repeated(self): + """Test five messages with every possible repeated field repeated.""" + all_simple = test_pb2.AllSimple() + all_simple.repeated_string.append("foo") + all_simple.repeated_string.append("foo2") + all_simple.repeated_int32.append(32) + all_simple.repeated_int32.append(322) + all_simple.repeated_uint32.append(123) + all_simple.repeated_uint32.append(1232) + all_simple.repeated_int64.append(123456) + all_simple.repeated_int64.append(1234562) + all_simple.repeated_uint64.append(123) + all_simple.repeated_uint64.append(1232) + all_simple.repeated_float.append(1.0) + all_simple.repeated_float.append(2.0) + all_simple.repeated_double.append(1.5) + all_simple.repeated_double.append(2.5) + result = _get_full_message_level_runnable( + [ + all_simple, + test_pb2.AllSimple(), + test_pb2.AllSimple(), + all_simple, + test_pb2.AllSimple(), + all_simple, + test_pb2.AllSimple(), + ] + ) + self.assertAllEqual(result["repeated_string"][INDEX], [0, 0, 3, 3, 5, 5]) + self.assertAllEqual( + result["repeated_string"][VALUE], + [b"foo", b"foo2", b"foo", b"foo2", b"foo", b"foo2"], + ) + self.assertAllEqual(result["repeated_int32"][INDEX], [0, 0, 3, 3, 5, 5]) + self.assertAllEqual( + result["repeated_int32"][VALUE], [32, 322, 32, 322, 32, 322] + ) + self.assertAllEqual(result["repeated_uint32"][INDEX], [0, 0, 3, 3, 5, 5]) + self.assertAllEqual( + 
result["repeated_uint32"][VALUE], [123, 1232, 123, 1232, 123, 1232] + ) + self.assertAllEqual(result["repeated_int64"][INDEX], [0, 0, 3, 3, 5, 5]) + self.assertAllEqual( + result["repeated_int64"][VALUE], + [123456, 1234562, 123456, 1234562, 123456, 1234562], + ) + self.assertAllEqual(result["repeated_uint64"][INDEX], [0, 0, 3, 3, 5, 5]) + self.assertAllEqual( + result["repeated_uint64"][VALUE], [123, 1232, 123, 1232, 123, 1232] + ) + self.assertAllEqual(result["repeated_float"][INDEX], [0, 0, 3, 3, 5, 5]) + self.assertAllEqual( + result["repeated_float"][VALUE], [1.0, 2.0, 1.0, 2.0, 1.0, 2.0] + ) + self.assertAllEqual(result["repeated_double"][INDEX], [0, 0, 3, 3, 5, 5]) + self.assertAllEqual( + result["repeated_double"][VALUE], [1.5, 2.5, 1.5, 2.5, 1.5, 2.5] + ) + + def test_parse_full_message_level_for_all_simple_repeated(self): + """Test a single message with every possible repeated field repeated.""" + all_simple = test_pb2.AllSimple() + all_simple.repeated_string.append("foo") + all_simple.repeated_string.append("foo2") + all_simple.repeated_int32.append(32) + all_simple.repeated_int32.append(322) + all_simple.repeated_uint32.append(123) + all_simple.repeated_uint32.append(1232) + all_simple.repeated_int64.append(123456) + all_simple.repeated_int64.append(1234562) + all_simple.repeated_uint64.append(123) + all_simple.repeated_uint64.append(1232) + all_simple.repeated_float.append(1.0) + all_simple.repeated_float.append(2.0) + all_simple.repeated_double.append(1.5) + all_simple.repeated_double.append(2.5) + result = _get_full_message_level_runnable([all_simple]) + self.assertAllEqual(result["repeated_string"][INDEX], [0, 0]) + self.assertAllEqual(result["repeated_string"][VALUE], [b"foo", b"foo2"]) + self.assertAllEqual(result["repeated_int32"][INDEX], [0, 0]) + self.assertAllEqual(result["repeated_int32"][VALUE], [32, 322]) + self.assertAllEqual(result["repeated_uint32"][INDEX], [0, 0]) + self.assertAllEqual(result["repeated_uint32"][VALUE], [123, 1232]) + self.assertAllEqual(result["repeated_int64"][INDEX], [0, 0]) + self.assertAllEqual(result["repeated_int64"][VALUE], [123456, 1234562]) + self.assertAllEqual(result["repeated_uint64"][INDEX], [0, 0]) + self.assertAllEqual(result["repeated_uint64"][VALUE], [123, 1232]) + self.assertAllEqual(result["repeated_float"][INDEX], [0, 0]) + self.assertAllEqual(result["repeated_float"][VALUE], [1.0, 2.0]) + self.assertAllEqual(result["repeated_double"][INDEX], [0, 0]) + self.assertAllEqual(result["repeated_double"][VALUE], [1.5, 2.5]) + + def test_parse_full_message_level_for_all_simple(self): + """Test a single message with every possible primitive field.""" + all_simple = test_pb2.AllSimple() + all_simple.optional_string = "foo" + all_simple.optional_int32 = -5 + all_simple.optional_uint32 = 2**31 + all_simple.optional_int64 = 100123 + all_simple.optional_uint64 = 2**63 + all_simple.optional_float = 6.5 + all_simple.optional_double = -7.0 + all_simple.repeated_string.append("foo") + all_simple.repeated_int32.append(32) + all_simple.repeated_uint32.append(123) + all_simple.repeated_int64.append(123456) + all_simple.repeated_uint64.append(123) + all_simple.repeated_float.append(1.0) + all_simple.repeated_double.append(1.5) + runnable = _get_full_message_level_runnable([all_simple]) + self.assertLen(runnable["optional_string"][INDEX].shape.dims, 1) + self.assertLen(runnable["optional_string"][VALUE].shape.dims, 1) + self.assertLen(runnable["repeated_string"][INDEX].shape.dims, 1) + self.assertLen(runnable["repeated_string"][VALUE].shape.dims, 
1) + + result = runnable + self.assertAllEqual(result["optional_string"][INDEX], [0]) + self.assertAllEqual(result["optional_string"][VALUE], [b"foo"]) + self.assertAllEqual(result["optional_int32"][INDEX], [0]) + self.assertAllEqual(result["optional_int32"][VALUE], [-5]) + self.assertAllEqual(result["optional_uint32"][INDEX], [0]) + self.assertAllEqual(result["optional_uint32"][VALUE], [2**31]) + self.assertAllEqual(result["optional_int64"][INDEX], [0]) + self.assertAllEqual(result["optional_int64"][VALUE], [100123]) + self.assertAllEqual(result["optional_uint64"][INDEX], [0]) + self.assertAllEqual(result["optional_uint64"][VALUE], [2**63]) + self.assertAllEqual(result["optional_float"][INDEX], [0]) + self.assertAllEqual(result["optional_float"][VALUE], [6.5]) + self.assertAllEqual(result["optional_double"][INDEX], [0]) + self.assertAllEqual(result["optional_double"][VALUE], [-7.0]) + # TODO(martinz): test the repeated fields too. + + def test_parse_full_message_level_action(self): + action = test_pb2.Action() + action.doc_id = "3" + action.number_of_views = 3 + result = _get_full_message_level_runnable([action]) + self.assertAllEqual(result["doc_id"][INDEX], [0]) + self.assertAllEqual(result["doc_id"][VALUE], [b"3"]) + self.assertAllEqual(result["number_of_views"][INDEX], [0]) + self.assertAllEqual(result["number_of_views"][VALUE], [3]) + + def test_parse_message_level(self): + action = test_pb2.Action() + action.doc_id = "3" + action.number_of_views = 3 + tensor_of_protos = tf.constant([action.SerializeToString()]) + [field_tuple] = struct2tensor_ops.parse_message_level( + tensor_of_protos, test_pb2.Action().DESCRIPTOR, ["number_of_views"] + ) + values = field_tuple.value + indices = field_tuple.index + self.assertAllEqual(indices, [0]) + self.assertAllEqual(values, [3]) + + def test_parse_extension(self): + user_info = test_pb2.UserInfo() + user_info.Extensions[ + test_pb2.LocationOfExtension.special_user_info + ].secret = "shhh" + expected_value = test_pb2.SpecialUserInfo() + expected_value.secret = "shhh" + tensor_of_protos = tf.constant([user_info.SerializeToString()]) + [field_tuple] = struct2tensor_ops.parse_message_level( + tensor_of_protos, + test_pb2.UserInfo().DESCRIPTOR, + ["(struct2tensor.test.LocationOfExtension.special_user_info)"], + ) + self.assertAllEqual(field_tuple.index, [0]) + self.assertAllEqual(field_tuple.value, [expected_value.SerializeToString()]) + + def test_parse_external_extension(self): + user_info = test_pb2.UserInfo() + user_info.Extensions[ + test_extension_pb2.MyExternalExtension.ext + ].special = "shhh" + expected_value = test_extension_pb2.MyExternalExtension() + expected_value.special = "shhh" + tensor_of_protos = tf.constant([user_info.SerializeToString()]) + [field_tuple] = struct2tensor_ops.parse_message_level( + tensor_of_protos, + test_pb2.UserInfo().DESCRIPTOR, + ["(struct2tensor.test.MyExternalExtension.ext)"], + ) + self.assertAllEqual(field_tuple.index, [0]) + self.assertAllEqual(field_tuple.value, [expected_value.SerializeToString()]) + + def test_parse_packed_fields(self): + message_with_packed_fields = test_pb2.HasPackedFields( + packed_int32=[-1, -2, -3], + packed_uint32=[100000, 200000, 300000], + packed_int64=[-400000, -500000, -600000], + packed_uint64=[4, 5, 6], + packed_float=[7.0, 8.0, 9.0], + packed_double=[10.0, 11.0, 12.0], + ) + tensor_of_protos = tf.constant( + [message_with_packed_fields.SerializeToString()] * 2 + ) + + parsed_tuples = struct2tensor_ops.parse_message_level( + tensor_of_protos, + 
test_pb2.HasPackedFields.DESCRIPTOR, + [ + "packed_int32", + "packed_uint32", + "packed_int64", + "packed_uint64", + "packed_float", + "packed_double", + ], + ) + indices = { + parsed_tuple.field_name: parsed_tuple.index + for parsed_tuple in parsed_tuples + } + values = { + parsed_tuple.field_name: parsed_tuple.value + for parsed_tuple in parsed_tuples + } + + for index in indices.values(): + self.assertAllEqual(index, [0, 0, 0, 1, 1, 1]) + for field_name, value in values.items(): self.assertAllEqual( - expected_field_value[f.field_name], f.value, - f"field: {f.field_name}, permutation: {indices}, field_to_parse: {comb}") - - def test_out_of_order_repeated_fields_1(self): - # This is a 2-1-2 wire number pattern. - proto = ( - test_pb2.Event(query_token=["aaa"]).SerializeToString() + - test_pb2.Event(event_id="abc").SerializeToString() + - test_pb2.Event(query_token=["bbb"]).SerializeToString()) - expected_field_value = { - "query_token": [b"aaa", b"bbb"], - "event_id": [b"abc"] - } - - for fields_to_parse in [["query_token"], ["event_id"], - ["query_token", "event_id"]]: - parsed_fields = struct2tensor_ops.parse_message_level( - [proto], test_pb2.Event.DESCRIPTOR, fields_to_parse) - for f in parsed_fields: - self.assertAllEqual(expected_field_value[f.field_name], f.value) - - def test_out_of_order_repeated_fields_2(self): - # This follows a 3-5-3 wire number pattern, where 3 and 4 parsed fields. - proto = ( - test_pb2.Event(query_token=["aaa"]).SerializeToString() + - test_pb2.Event(action_mask=[True]).SerializeToString() + - test_pb2.Event(query_token=["bbb"]).SerializeToString()) - expected_field_value = { - "query_token": [b"aaa", b"bbb"], - "action_mask": [True], - "action": [] - } - for fields_to_parse in [["query_token"], ["action_mask"], - ["query_token", "action_mask"], - ["query_token", "action"], - ["query_token", "action_mask", "action"]]: - parsed_fields = struct2tensor_ops.parse_message_level( - [proto], test_pb2.Event.DESCRIPTOR, fields_to_parse) - for f in parsed_fields: - expected_value = expected_field_value[f.field_name] - if expected_value: - self.assertAllEqual(expected_value, f.value) - - def test_out_of_order_repeated_fields_3(self): - # This follows a 3-5-3 wire number pattern, where 3 and 4 parsed fields. 
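
# A minimal, hedged sketch of the protobuf behavior the out-of-order tests in
# this hunk rely on: concatenating serialized messages is a wire-format merge,
# so repeated fields from every fragment survive, in wire order. `test_pb2.Event`
# is the test-only proto these tests already use; nothing here is part of the
# patch itself.
#
#   fragment_a = test_pb2.Event(query_token=["aaa"]).SerializeToString()
#   fragment_b = test_pb2.Event(event_id="abc").SerializeToString()
#   merged = test_pb2.Event.FromString(fragment_a + fragment_b)
#   assert list(merged.query_token) == ["aaa"]
#   assert merged.event_id == "abc"
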
- proto = ( - test_pb2.AllSimple(repeated_string=["aaa"]).SerializeToString() + - test_pb2.AllSimple(repeated_int64=[12345]).SerializeToString() + - test_pb2.AllSimple(repeated_string=["bbb"]).SerializeToString()) - expected_field_value = { - "repeated_string": [b"aaa", b"bbb"], - "repeated_int64": [12345], - "repeated_int32": [], - "repeated_uint32": [] - } - - for fields_to_parse in [["repeated_int64"], ["repeated_string"], - [ - "repeated_string", - "repeated_uint32", "repeated_int32" - ]]: - parsed_fields = struct2tensor_ops.parse_message_level( - [proto], test_pb2.AllSimple.DESCRIPTOR, fields_to_parse) - for f in parsed_fields: - self.assertAllEqual(expected_field_value[f.field_name], f.value) - - def test_parse_full_message_level_for_event(self): - event = test_pb2.Event() - event.event_id = "foo" - event.query = "query" - event.query_token.append("a") - event.query_token.append("b") - action0 = event.action.add() - action0.doc_id = "abc" - action1 = event.action.add() - event.user_info.age_in_years = 38 - event2 = test_pb2.Event() - action2 = event2.action.add() - action2.doc_id = "def" - parsed_field_dict = _parse_full_message_level_as_dict([event, event2]) - doc_id = parsed_field_dict["action"] - serialized_actions = [ - proto.SerializeToString() for proto in [action0, action1, action2] - ] - - self.assertAllEqual(doc_id.index, [0, 0, 1]) - self.assertAllEqual(doc_id.value, serialized_actions) - - def test_parse_full_message_level_for_simple_action_multiple(self): - """Test multiple messages.""" - as1 = test_pb2.AllSimple() - as1.optional_string = "a" - as1.repeated_string.append("b") - as1.repeated_string.append("c") - as2 = test_pb2.AllSimple() - as2.optional_string = "d" - as2.optional_int32 = 123 - as3 = test_pb2.AllSimple() - as3.repeated_string.append("d") - as3.repeated_string.append("e") - as3.optional_int32 = 123 - - parsed_field_dict = _parse_full_message_level_as_dict([as1, as2, as3]) - - doc_id = parsed_field_dict["repeated_string"] - self.assertAllEqual(doc_id.index, [0, 0, 2, 2]) - self.assertAllEqual(doc_id.value, [b"b", b"c", b"d", b"e"]) - - def test_parse_full_message_level_for_all_simple_repeated_repeated(self): - """Test five messages with every possible repeated field repeated.""" - all_simple = test_pb2.AllSimple() - all_simple.repeated_string.append("foo") - all_simple.repeated_string.append("foo2") - all_simple.repeated_int32.append(32) - all_simple.repeated_int32.append(322) - all_simple.repeated_uint32.append(123) - all_simple.repeated_uint32.append(1232) - all_simple.repeated_int64.append(123456) - all_simple.repeated_int64.append(1234562) - all_simple.repeated_uint64.append(123) - all_simple.repeated_uint64.append(1232) - all_simple.repeated_float.append(1.0) - all_simple.repeated_float.append(2.0) - all_simple.repeated_double.append(1.5) - all_simple.repeated_double.append(2.5) - result = _get_full_message_level_runnable([ - all_simple, - test_pb2.AllSimple(), - test_pb2.AllSimple(), all_simple, - test_pb2.AllSimple(), all_simple, - test_pb2.AllSimple() - ]) - self.assertAllEqual(result["repeated_string"][INDEX], [0, 0, 3, 3, 5, 5]) - self.assertAllEqual(result["repeated_string"][VALUE], - [b"foo", b"foo2", b"foo", b"foo2", b"foo", b"foo2"]) - self.assertAllEqual(result["repeated_int32"][INDEX], [0, 0, 3, 3, 5, 5]) - self.assertAllEqual(result["repeated_int32"][VALUE], - [32, 322, 32, 322, 32, 322]) - self.assertAllEqual(result["repeated_uint32"][INDEX], [0, 0, 3, 3, 5, 5]) - self.assertAllEqual(result["repeated_uint32"][VALUE], - [123, 1232, 123, 
1232, 123, 1232]) - self.assertAllEqual(result["repeated_int64"][INDEX], [0, 0, 3, 3, 5, 5]) - self.assertAllEqual(result["repeated_int64"][VALUE], - [123456, 1234562, 123456, 1234562, 123456, 1234562]) - self.assertAllEqual(result["repeated_uint64"][INDEX], [0, 0, 3, 3, 5, 5]) - self.assertAllEqual(result["repeated_uint64"][VALUE], - [123, 1232, 123, 1232, 123, 1232]) - self.assertAllEqual(result["repeated_float"][INDEX], [0, 0, 3, 3, 5, 5]) - self.assertAllEqual(result["repeated_float"][VALUE], - [1.0, 2.0, 1.0, 2.0, 1.0, 2.0]) - self.assertAllEqual(result["repeated_double"][INDEX], [0, 0, 3, 3, 5, 5]) - self.assertAllEqual(result["repeated_double"][VALUE], - [1.5, 2.5, 1.5, 2.5, 1.5, 2.5]) - - def test_parse_full_message_level_for_all_simple_repeated(self): - """Test a single message with every possible repeated field repeated.""" - all_simple = test_pb2.AllSimple() - all_simple.repeated_string.append("foo") - all_simple.repeated_string.append("foo2") - all_simple.repeated_int32.append(32) - all_simple.repeated_int32.append(322) - all_simple.repeated_uint32.append(123) - all_simple.repeated_uint32.append(1232) - all_simple.repeated_int64.append(123456) - all_simple.repeated_int64.append(1234562) - all_simple.repeated_uint64.append(123) - all_simple.repeated_uint64.append(1232) - all_simple.repeated_float.append(1.0) - all_simple.repeated_float.append(2.0) - all_simple.repeated_double.append(1.5) - all_simple.repeated_double.append(2.5) - result = _get_full_message_level_runnable([all_simple]) - self.assertAllEqual(result["repeated_string"][INDEX], [0, 0]) - self.assertAllEqual(result["repeated_string"][VALUE], [b"foo", b"foo2"]) - self.assertAllEqual(result["repeated_int32"][INDEX], [0, 0]) - self.assertAllEqual(result["repeated_int32"][VALUE], [32, 322]) - self.assertAllEqual(result["repeated_uint32"][INDEX], [0, 0]) - self.assertAllEqual(result["repeated_uint32"][VALUE], [123, 1232]) - self.assertAllEqual(result["repeated_int64"][INDEX], [0, 0]) - self.assertAllEqual(result["repeated_int64"][VALUE], [123456, 1234562]) - self.assertAllEqual(result["repeated_uint64"][INDEX], [0, 0]) - self.assertAllEqual(result["repeated_uint64"][VALUE], [123, 1232]) - self.assertAllEqual(result["repeated_float"][INDEX], [0, 0]) - self.assertAllEqual(result["repeated_float"][VALUE], [1.0, 2.0]) - self.assertAllEqual(result["repeated_double"][INDEX], [0, 0]) - self.assertAllEqual(result["repeated_double"][VALUE], [1.5, 2.5]) - - def test_parse_full_message_level_for_all_simple(self): - """Test a single message with every possible primitive field.""" - all_simple = test_pb2.AllSimple() - all_simple.optional_string = "foo" - all_simple.optional_int32 = -5 - all_simple.optional_uint32 = 2**31 - all_simple.optional_int64 = 100123 - all_simple.optional_uint64 = 2**63 - all_simple.optional_float = 6.5 - all_simple.optional_double = -7.0 - all_simple.repeated_string.append("foo") - all_simple.repeated_int32.append(32) - all_simple.repeated_uint32.append(123) - all_simple.repeated_int64.append(123456) - all_simple.repeated_uint64.append(123) - all_simple.repeated_float.append(1.0) - all_simple.repeated_double.append(1.5) - runnable = _get_full_message_level_runnable([all_simple]) - self.assertLen(runnable["optional_string"][INDEX].shape.dims, 1) - self.assertLen(runnable["optional_string"][VALUE].shape.dims, 1) - self.assertLen(runnable["repeated_string"][INDEX].shape.dims, 1) - self.assertLen(runnable["repeated_string"][VALUE].shape.dims, 1) - - result = runnable - 
self.assertAllEqual(result["optional_string"][INDEX], [0]) - self.assertAllEqual(result["optional_string"][VALUE], [b"foo"]) - self.assertAllEqual(result["optional_int32"][INDEX], [0]) - self.assertAllEqual(result["optional_int32"][VALUE], [-5]) - self.assertAllEqual(result["optional_uint32"][INDEX], [0]) - self.assertAllEqual(result["optional_uint32"][VALUE], [2**31]) - self.assertAllEqual(result["optional_int64"][INDEX], [0]) - self.assertAllEqual(result["optional_int64"][VALUE], [100123]) - self.assertAllEqual(result["optional_uint64"][INDEX], [0]) - self.assertAllEqual(result["optional_uint64"][VALUE], [2**63]) - self.assertAllEqual(result["optional_float"][INDEX], [0]) - self.assertAllEqual(result["optional_float"][VALUE], [6.5]) - self.assertAllEqual(result["optional_double"][INDEX], [0]) - self.assertAllEqual(result["optional_double"][VALUE], [-7.0]) - # TODO(martinz): test the repeated fields too. - - def test_parse_full_message_level_action(self): - action = test_pb2.Action() - action.doc_id = "3" - action.number_of_views = 3 - result = _get_full_message_level_runnable([action]) - self.assertAllEqual(result["doc_id"][INDEX], [0]) - self.assertAllEqual(result["doc_id"][VALUE], [b"3"]) - self.assertAllEqual(result["number_of_views"][INDEX], [0]) - self.assertAllEqual(result["number_of_views"][VALUE], [3]) - - def test_parse_message_level(self): - action = test_pb2.Action() - action.doc_id = "3" - action.number_of_views = 3 - tensor_of_protos = tf.constant([action.SerializeToString()]) - [field_tuple - ] = struct2tensor_ops.parse_message_level(tensor_of_protos, - test_pb2.Action().DESCRIPTOR, - ["number_of_views"]) - values = field_tuple.value - indices = field_tuple.index - self.assertAllEqual(indices, [0]) - self.assertAllEqual(values, [3]) - - def test_parse_extension(self): - user_info = test_pb2.UserInfo() - user_info.Extensions[ - test_pb2.LocationOfExtension.special_user_info].secret = "shhh" - expected_value = test_pb2.SpecialUserInfo() - expected_value.secret = "shhh" - tensor_of_protos = tf.constant([user_info.SerializeToString()]) - [field_tuple] = struct2tensor_ops.parse_message_level( - tensor_of_protos, - test_pb2.UserInfo().DESCRIPTOR, - ["(struct2tensor.test.LocationOfExtension.special_user_info)"]) - self.assertAllEqual(field_tuple.index, [0]) - self.assertAllEqual(field_tuple.value, [expected_value.SerializeToString()]) - - def test_parse_external_extension(self): - user_info = test_pb2.UserInfo() - user_info.Extensions[ - test_extension_pb2.MyExternalExtension.ext].special = "shhh" - expected_value = test_extension_pb2.MyExternalExtension() - expected_value.special = "shhh" - tensor_of_protos = tf.constant([user_info.SerializeToString()]) - [field_tuple] = struct2tensor_ops.parse_message_level( - tensor_of_protos, - test_pb2.UserInfo().DESCRIPTOR, - ["(struct2tensor.test.MyExternalExtension.ext)"]) - self.assertAllEqual(field_tuple.index, [0]) - self.assertAllEqual(field_tuple.value, [expected_value.SerializeToString()]) - - - def test_parse_packed_fields(self): - message_with_packed_fields = test_pb2.HasPackedFields( - packed_int32=[-1, -2, -3], - packed_uint32=[100000, 200000, 300000], - packed_int64=[-400000, -500000, -600000], - packed_uint64=[4, 5, 6], - packed_float=[7.0, 8.0, 9.0], - packed_double=[10.0, 11.0, 12.0], - ) - tensor_of_protos = tf.constant( - [message_with_packed_fields.SerializeToString()] * 2) - - parsed_tuples = struct2tensor_ops.parse_message_level( - tensor_of_protos, test_pb2.HasPackedFields.DESCRIPTOR, [ - "packed_int32", - 
"packed_uint32", - "packed_int64", - "packed_uint64", - "packed_float", - "packed_double", - ]) - indices = { - parsed_tuple.field_name: parsed_tuple.index - for parsed_tuple in parsed_tuples - } - values = { - parsed_tuple.field_name: parsed_tuple.value - for parsed_tuple in parsed_tuples - } - - for index in indices.values(): - self.assertAllEqual(index, [0, 0, 0, 1, 1, 1]) - for field_name, value in values.items(): - self.assertAllEqual( - value, - list(getattr(message_with_packed_fields, field_name)) * 2) - - - def test_proto2_optional_field_with_honor_proto3_optional_semantic(self): - proto2_message1 = test_pb2.AllSimple() - proto2_message2 = test_pb2.AllSimple(optional_string="a") - tensor_of_protos = tf.constant([proto2_message1.SerializeToString(), - proto2_message2.SerializeToString(), - proto2_message1.SerializeToString()]) - parsed_tuples = struct2tensor_ops.parse_message_level( - tensor_of_protos, test_pb2.AllSimple.DESCRIPTOR, [ - "optional_string", - ], honor_proto3_optional_semantics=True) - indices = { - parsed_tuple.field_name: parsed_tuple.index - for parsed_tuple in parsed_tuples - } - values = { - parsed_tuple.field_name: parsed_tuple.value - for parsed_tuple in parsed_tuples - } - # Only the second proto has value. No default value should be inserted. - for idx in indices.values(): - self.assertAllEqual([1], idx) - for value in values.values(): - self.assertAllEqual([b"a"], value) - - def test_make_repeated_basic(self): - parent_index = tf.constant([0, 0, 4, 4, 4, 7, 8, 9], dtype=tf.int64) - values = tf.constant(["a", "b", "c", "d", "e", "f", "g", "h"]) - sparse_tensor = struct2tensor_ops.create_sparse_tensor_for_repeated( - parent_index, values, tf.constant([10, 3], dtype=tf.int64)) - self.assertAllEqual( - sparse_tensor.indices, - [[0, 0], [0, 1], [4, 0], [4, 1], [4, 2], [7, 0], [8, 0], [9, 0]]) - - def test_make_repeated_empty(self): - parent_index = tf.constant([], dtype=tf.int64) - values = tf.constant([], dtype=tf.int32) - sparse_tensor = struct2tensor_ops.create_sparse_tensor_for_repeated( - parent_index, values, tf.constant([0, 0], dtype=tf.int64)) - self.assertAllEqual(sparse_tensor.indices.shape, [0, 2]) - - @parameterized.named_parameters(*_EQUIJOIN_TEST_CASES) - def test_equi_join_indices(self, a, b, expected_index_a, expected_index_b): - a = tf.constant(a, dtype=tf.int64) - b = tf.constant(b, dtype=tf.int64) - - # Test equi_join_indices - [index_a, index_b] = struct2tensor_ops.equi_join_indices(a, b) - self.assertAllEqual(index_a, expected_index_a) - self.assertAllEqual(index_b, expected_index_b) - - # Test equi_join_any_indices - [index_a, index_b] = struct2tensor_ops.equi_join_any_indices(a, b) - self.assertAllEqual(index_a, expected_index_a) - self.assertAllEqual(index_b, expected_index_b) - - def test_equi_join_any_indices_non_monotonic(self): - a = tf.constant([0, 1, 2, 1, 2], dtype=tf.int64) - b = tf.constant([0, 1, 1, 2, 3], dtype=tf.int64) - [index_a, index_b] = struct2tensor_ops.equi_join_any_indices(a, b) - self.assertAllEqual(index_a, [0, 1, 1, 2, 3, 3, 4]) - self.assertAllEqual(index_b, [0, 1, 2, 3, 1, 2, 3]) - - def test_run_length_before(self): - """Breaking down the broadcast.""" - a = tf.constant([0, 1, 1, 7, 8, 8, 9], dtype=tf.int64) - b = struct2tensor_ops.run_length_before(a) - self.assertAllEqual(b, [0, 0, 1, 0, 0, 1, 0]) - - def test_run_length_before_empty(self): - """Breaking down the broadcast.""" - a = tf.constant([], dtype=tf.int64) - b = struct2tensor_ops.run_length_before(a) - self.assertAllEqual(b, []) - -_SIGNED_INTEGER_TYPES 
= [ - "int32", "int64", "sfixed32", "sfixed64", "sint32", "sint64" -] + value, list(getattr(message_with_packed_fields, field_name)) * 2 + ) + + def test_proto2_optional_field_with_honor_proto3_optional_semantic(self): + proto2_message1 = test_pb2.AllSimple() + proto2_message2 = test_pb2.AllSimple(optional_string="a") + tensor_of_protos = tf.constant( + [ + proto2_message1.SerializeToString(), + proto2_message2.SerializeToString(), + proto2_message1.SerializeToString(), + ] + ) + parsed_tuples = struct2tensor_ops.parse_message_level( + tensor_of_protos, + test_pb2.AllSimple.DESCRIPTOR, + [ + "optional_string", + ], + honor_proto3_optional_semantics=True, + ) + indices = { + parsed_tuple.field_name: parsed_tuple.index + for parsed_tuple in parsed_tuples + } + values = { + parsed_tuple.field_name: parsed_tuple.value + for parsed_tuple in parsed_tuples + } + # Only the second proto has value. No default value should be inserted. + for idx in indices.values(): + self.assertAllEqual([1], idx) + for value in values.values(): + self.assertAllEqual([b"a"], value) + + def test_make_repeated_basic(self): + parent_index = tf.constant([0, 0, 4, 4, 4, 7, 8, 9], dtype=tf.int64) + values = tf.constant(["a", "b", "c", "d", "e", "f", "g", "h"]) + sparse_tensor = struct2tensor_ops.create_sparse_tensor_for_repeated( + parent_index, values, tf.constant([10, 3], dtype=tf.int64) + ) + self.assertAllEqual( + sparse_tensor.indices, + [[0, 0], [0, 1], [4, 0], [4, 1], [4, 2], [7, 0], [8, 0], [9, 0]], + ) + + def test_make_repeated_empty(self): + parent_index = tf.constant([], dtype=tf.int64) + values = tf.constant([], dtype=tf.int32) + sparse_tensor = struct2tensor_ops.create_sparse_tensor_for_repeated( + parent_index, values, tf.constant([0, 0], dtype=tf.int64) + ) + self.assertAllEqual(sparse_tensor.indices.shape, [0, 2]) + + @parameterized.named_parameters(*_EQUIJOIN_TEST_CASES) + def test_equi_join_indices(self, a, b, expected_index_a, expected_index_b): + a = tf.constant(a, dtype=tf.int64) + b = tf.constant(b, dtype=tf.int64) + + # Test equi_join_indices + [index_a, index_b] = struct2tensor_ops.equi_join_indices(a, b) + self.assertAllEqual(index_a, expected_index_a) + self.assertAllEqual(index_b, expected_index_b) + + # Test equi_join_any_indices + [index_a, index_b] = struct2tensor_ops.equi_join_any_indices(a, b) + self.assertAllEqual(index_a, expected_index_a) + self.assertAllEqual(index_b, expected_index_b) + + def test_equi_join_any_indices_non_monotonic(self): + a = tf.constant([0, 1, 2, 1, 2], dtype=tf.int64) + b = tf.constant([0, 1, 1, 2, 3], dtype=tf.int64) + [index_a, index_b] = struct2tensor_ops.equi_join_any_indices(a, b) + self.assertAllEqual(index_a, [0, 1, 1, 2, 3, 3, 4]) + self.assertAllEqual(index_b, [0, 1, 2, 3, 1, 2, 3]) + + def test_run_length_before(self): + """Breaking down the broadcast.""" + a = tf.constant([0, 1, 1, 7, 8, 8, 9], dtype=tf.int64) + b = struct2tensor_ops.run_length_before(a) + self.assertAllEqual(b, [0, 0, 1, 0, 0, 1, 0]) + + def test_run_length_before_empty(self): + """Breaking down the broadcast.""" + a = tf.constant([], dtype=tf.int64) + b = struct2tensor_ops.run_length_before(a) + self.assertAllEqual(b, []) + + +_SIGNED_INTEGER_TYPES = ["int32", "int64", "sfixed32", "sfixed64", "sint32", "sint64"] _UNSIGNED_INTEGER_TYPES = ["uint32", "uint64", "fixed32", "fixed64"] @test_util.run_all_in_graph_and_eager_modes class DecodeProtoMapOpTest(parameterized.TestCase, tf.test.TestCase): - - def _parse_map_entry(self, messages_with_map, map_field_name, keys_needed): - 
parsed_map_submessage = struct2tensor_ops.parse_message_level( - tf.constant([m.SerializeToString() for m in messages_with_map]), - test_map_pb2.MessageWithMap.DESCRIPTOR, [map_field_name])[0] - - return struct2tensor_ops.parse_proto_map( - parsed_map_submessage.value, parsed_map_submessage.index, - parsed_map_submessage.field_descriptor.message_type, keys_needed) - - @parameterized.named_parameters( - [dict(testcase_name=t, key_type=t) for t in _SIGNED_INTEGER_TYPES]) - def test_signed_integer_key_types(self, key_type): - field_name = f"{key_type}_string_map" - message_with_map = test_map_pb2.MessageWithMap() - map_entry = getattr(message_with_map, f"{key_type}_string_map") - map_entry[42] = "hello" - map_entry[-42] = "world" - - [(values_42, indices_42), (values_n42, indices_n42), - (values_0, indices_0)] = self._parse_map_entry([message_with_map], - field_name, - ["42", "-42", "0"]) - - self.assertAllEqual(values_42, [b"hello"]) - self.assertAllEqual(values_n42, [b"world"]) - self.assertAllEqual(values_0, []) - self.assertAllEqual(indices_42, [0]) - self.assertAllEqual(indices_n42, [0]) - self.assertAllEqual(indices_0, []) - - @parameterized.named_parameters( - [dict(testcase_name=t, key_type=t) for t in _UNSIGNED_INTEGER_TYPES]) - def test_unsigned_integer_key_types(self, key_type): - field_name = f"{key_type}_string_map" - message_with_map = test_map_pb2.MessageWithMap() - map_entry = getattr(message_with_map, f"{key_type}_string_map") - map_entry[42] = "hello" - - [(values_42, indices_42), - (values_0, indices_0)] = self._parse_map_entry([message_with_map], - field_name, ["42", "0"]) - self.assertAllEqual(values_42, [b"hello"]) - self.assertAllEqual(values_0, []) - self.assertAllEqual(indices_42, [0]) - self.assertAllEqual(indices_0, []) - - def test_invalid_uint32_key(self): - with self.assertRaisesRegex( - tf.errors.InvalidArgumentError, "Failed to parse .*string" - ): - self.evaluate( - self._parse_map_entry([test_map_pb2.MessageWithMap()], - "uint32_string_map", ["-42"])) - - def test_invalid_int32_key(self): - with self.assertRaisesRegex( - tf.errors.InvalidArgumentError, "Failed to parse .*string" - ): - self.evaluate( - self._parse_map_entry([test_map_pb2.MessageWithMap()], - "int32_string_map", ["foo"])) - - def test_bool_key_type(self): - message_with_map = test_map_pb2.MessageWithMap() - message_with_map.bool_string_map[False] = "hello" - [(values_false, indices_false), (values_true, indices_true) - ] = self._parse_map_entry([message_with_map], "bool_string_map", ["0", "1"]) - self.assertAllEqual(values_true, []) - self.assertAllEqual(values_false, [b"hello"]) - self.assertAllEqual(indices_true, []) - self.assertAllEqual(indices_false, [0]) - - def test_invalid_bool_key(self): - message_with_map = test_map_pb2.MessageWithMap() - with self.assertRaisesRegex( - tf.errors.InvalidArgumentError, "Failed to parse .*string" - ): - self.evaluate( - self._parse_map_entry([message_with_map], "bool_string_map", ["2"])) - - @parameterized.named_parameters( - [dict(testcase_name=t, value_type=t) for t in _SIGNED_INTEGER_TYPES]) - def test_signed_integer_value_types(self, value_type): - field_name = f"string_{value_type}_map" - message_with_map = test_map_pb2.MessageWithMap() - map_entry = getattr(message_with_map, f"string_{value_type}_map") - map_entry["foo"] = 42 - map_entry["bar"] = -42 - [(values_foo, indices_foo), (values_bar, indices_bar), - (values_null, indices_null)] = self._parse_map_entry([message_with_map], - field_name, - ["foo", "bar", ""]) - 
self.assertAllEqual(values_foo, [42]) - self.assertAllEqual(values_bar, [-42]) - self.assertAllEqual(values_null, []) - self.assertAllEqual(indices_foo, [0]) - self.assertAllEqual(indices_bar, [0]) - self.assertAllEqual(indices_null, []) - - @parameterized.named_parameters( - [dict(testcase_name=t, value_type=t) for t in _UNSIGNED_INTEGER_TYPES]) - def test_unsigned_integer_value_types(self, value_type): - field_name = f"string_{value_type}_map" - message_with_map = test_map_pb2.MessageWithMap() - map_entry = getattr(message_with_map, f"string_{value_type}_map") - map_entry["foo"] = 42 - [(values_foo, indices_foo), (values_null, indices_null) - ] = self._parse_map_entry([message_with_map], field_name, ["foo", ""]) - self.assertAllEqual(values_foo, [42]) - self.assertAllEqual(values_null, []) - self.assertAllEqual(indices_foo, [0]) - self.assertAllEqual(indices_null, []) - - @parameterized.named_parameters( - [dict(testcase_name=t, value_type=t) for t in ["float", "double"]]) - def test_fp_value_types(self, value_type): - field_name = f"string_{value_type}_map" - message_with_map = test_map_pb2.MessageWithMap() - map_entry = getattr(message_with_map, f"string_{value_type}_map") - map_entry["foo"] = 0.5 - [(values_foo, indices_foo), (values_null, indices_null) - ] = self._parse_map_entry([message_with_map], field_name, ["foo", ""]) - self.assertAllEqual(values_foo, [0.5]) - self.assertAllEqual(values_null, []) - self.assertAllEqual(indices_foo, [0]) - self.assertAllEqual(indices_null, []) - - def test_enum_value_type(self): - message_with_map = test_map_pb2.MessageWithMap() - message_with_map.string_enum_map["foo"] = test_map_pb2.BAZ - [(values_foo, indices_foo), - (values_null, indices_null)] = self._parse_map_entry([message_with_map], - "string_enum_map", - ["foo", ""]) - self.assertAllEqual(values_foo, [int(test_map_pb2.BAZ)]) - self.assertAllEqual(values_null, []) - self.assertAllEqual(indices_foo, [0]) - self.assertAllEqual(indices_null, []) - - def test_message_value_type(self): - sub_message = test_map_pb2.SubMessage(repeated_int64=[1, 2, 3]) - message_with_map = test_map_pb2.MessageWithMap() - message_with_map.string_message_map["foo"].MergeFrom(sub_message) - [(values_foo, indices_foo), - (values_null, indices_null)] = self._parse_map_entry([message_with_map], - "string_message_map", - ["foo", ""]) - self.assertAllEqual(values_foo, [sub_message.SerializeToString()]) - self.assertAllEqual(values_null, []) - self.assertAllEqual(indices_foo, [0]) - self.assertAllEqual(indices_null, []) - - def test_multiple_messages(self): - message_with_map1 = test_map_pb2.MessageWithMap(string_string_map={ - "key1": "foo", - "key3": "bar" - }) - message_with_map2 = test_map_pb2.MessageWithMap() - message_with_map3 = test_map_pb2.MessageWithMap(string_string_map={ - "key2": "baz", - "key1": "kaz" - }) - [(values_key1, indices_key1), (values_key2, indices_key2), - (values_key3, indices_key3)] = self._parse_map_entry( - [message_with_map1, message_with_map2, message_with_map3], - "string_string_map", ["key1", "key2", "key3"]) - self.assertAllEqual(values_key1, [b"foo", b"kaz"]) - self.assertAllEqual(values_key2, [b"baz"]) - self.assertAllEqual(values_key3, [b"bar"]) - self.assertAllEqual(indices_key1, [0, 2]) - self.assertAllEqual(indices_key2, [2]) - self.assertAllEqual(indices_key3, [0]) - - def test_corrupted_message(self): - with self.assertRaises(tf.errors.DataLossError): - self.evaluate( - struct2tensor_ops.parse_proto_map( - tf.constant(["corrupted message"]), - tf.constant([0], dtype=tf.int64), 
test_map_pb2.MessageWithMap - .DESCRIPTOR.fields_by_name["int32_string_map"].message_type, - ["0"])) + def _parse_map_entry(self, messages_with_map, map_field_name, keys_needed): + parsed_map_submessage = struct2tensor_ops.parse_message_level( + tf.constant([m.SerializeToString() for m in messages_with_map]), + test_map_pb2.MessageWithMap.DESCRIPTOR, + [map_field_name], + )[0] + + return struct2tensor_ops.parse_proto_map( + parsed_map_submessage.value, + parsed_map_submessage.index, + parsed_map_submessage.field_descriptor.message_type, + keys_needed, + ) + + @parameterized.named_parameters( + [dict(testcase_name=t, key_type=t) for t in _SIGNED_INTEGER_TYPES] + ) + def test_signed_integer_key_types(self, key_type): + field_name = f"{key_type}_string_map" + message_with_map = test_map_pb2.MessageWithMap() + map_entry = getattr(message_with_map, f"{key_type}_string_map") + map_entry[42] = "hello" + map_entry[-42] = "world" + + [(values_42, indices_42), (values_n42, indices_n42), (values_0, indices_0)] = ( + self._parse_map_entry([message_with_map], field_name, ["42", "-42", "0"]) + ) + + self.assertAllEqual(values_42, [b"hello"]) + self.assertAllEqual(values_n42, [b"world"]) + self.assertAllEqual(values_0, []) + self.assertAllEqual(indices_42, [0]) + self.assertAllEqual(indices_n42, [0]) + self.assertAllEqual(indices_0, []) + + @parameterized.named_parameters( + [dict(testcase_name=t, key_type=t) for t in _UNSIGNED_INTEGER_TYPES] + ) + def test_unsigned_integer_key_types(self, key_type): + field_name = f"{key_type}_string_map" + message_with_map = test_map_pb2.MessageWithMap() + map_entry = getattr(message_with_map, f"{key_type}_string_map") + map_entry[42] = "hello" + + [(values_42, indices_42), (values_0, indices_0)] = self._parse_map_entry( + [message_with_map], field_name, ["42", "0"] + ) + self.assertAllEqual(values_42, [b"hello"]) + self.assertAllEqual(values_0, []) + self.assertAllEqual(indices_42, [0]) + self.assertAllEqual(indices_0, []) + + def test_invalid_uint32_key(self): + with self.assertRaisesRegex( + tf.errors.InvalidArgumentError, "Failed to parse .*string" + ): + self.evaluate( + self._parse_map_entry( + [test_map_pb2.MessageWithMap()], "uint32_string_map", ["-42"] + ) + ) + + def test_invalid_int32_key(self): + with self.assertRaisesRegex( + tf.errors.InvalidArgumentError, "Failed to parse .*string" + ): + self.evaluate( + self._parse_map_entry( + [test_map_pb2.MessageWithMap()], "int32_string_map", ["foo"] + ) + ) + + def test_bool_key_type(self): + message_with_map = test_map_pb2.MessageWithMap() + message_with_map.bool_string_map[False] = "hello" + [(values_false, indices_false), (values_true, indices_true)] = ( + self._parse_map_entry([message_with_map], "bool_string_map", ["0", "1"]) + ) + self.assertAllEqual(values_true, []) + self.assertAllEqual(values_false, [b"hello"]) + self.assertAllEqual(indices_true, []) + self.assertAllEqual(indices_false, [0]) + + def test_invalid_bool_key(self): + message_with_map = test_map_pb2.MessageWithMap() + with self.assertRaisesRegex( + tf.errors.InvalidArgumentError, "Failed to parse .*string" + ): + self.evaluate( + self._parse_map_entry([message_with_map], "bool_string_map", ["2"]) + ) + + @parameterized.named_parameters( + [dict(testcase_name=t, value_type=t) for t in _SIGNED_INTEGER_TYPES] + ) + def test_signed_integer_value_types(self, value_type): + field_name = f"string_{value_type}_map" + message_with_map = test_map_pb2.MessageWithMap() + map_entry = getattr(message_with_map, f"string_{value_type}_map") + map_entry["foo"] 
= 42 + map_entry["bar"] = -42 + [ + (values_foo, indices_foo), + (values_bar, indices_bar), + (values_null, indices_null), + ] = self._parse_map_entry([message_with_map], field_name, ["foo", "bar", ""]) + self.assertAllEqual(values_foo, [42]) + self.assertAllEqual(values_bar, [-42]) + self.assertAllEqual(values_null, []) + self.assertAllEqual(indices_foo, [0]) + self.assertAllEqual(indices_bar, [0]) + self.assertAllEqual(indices_null, []) + + @parameterized.named_parameters( + [dict(testcase_name=t, value_type=t) for t in _UNSIGNED_INTEGER_TYPES] + ) + def test_unsigned_integer_value_types(self, value_type): + field_name = f"string_{value_type}_map" + message_with_map = test_map_pb2.MessageWithMap() + map_entry = getattr(message_with_map, f"string_{value_type}_map") + map_entry["foo"] = 42 + [(values_foo, indices_foo), (values_null, indices_null)] = ( + self._parse_map_entry([message_with_map], field_name, ["foo", ""]) + ) + self.assertAllEqual(values_foo, [42]) + self.assertAllEqual(values_null, []) + self.assertAllEqual(indices_foo, [0]) + self.assertAllEqual(indices_null, []) + + @parameterized.named_parameters( + [dict(testcase_name=t, value_type=t) for t in ["float", "double"]] + ) + def test_fp_value_types(self, value_type): + field_name = f"string_{value_type}_map" + message_with_map = test_map_pb2.MessageWithMap() + map_entry = getattr(message_with_map, f"string_{value_type}_map") + map_entry["foo"] = 0.5 + [(values_foo, indices_foo), (values_null, indices_null)] = ( + self._parse_map_entry([message_with_map], field_name, ["foo", ""]) + ) + self.assertAllEqual(values_foo, [0.5]) + self.assertAllEqual(values_null, []) + self.assertAllEqual(indices_foo, [0]) + self.assertAllEqual(indices_null, []) + + def test_enum_value_type(self): + message_with_map = test_map_pb2.MessageWithMap() + message_with_map.string_enum_map["foo"] = test_map_pb2.BAZ + [(values_foo, indices_foo), (values_null, indices_null)] = ( + self._parse_map_entry([message_with_map], "string_enum_map", ["foo", ""]) + ) + self.assertAllEqual(values_foo, [int(test_map_pb2.BAZ)]) + self.assertAllEqual(values_null, []) + self.assertAllEqual(indices_foo, [0]) + self.assertAllEqual(indices_null, []) + + def test_message_value_type(self): + sub_message = test_map_pb2.SubMessage(repeated_int64=[1, 2, 3]) + message_with_map = test_map_pb2.MessageWithMap() + message_with_map.string_message_map["foo"].MergeFrom(sub_message) + [(values_foo, indices_foo), (values_null, indices_null)] = ( + self._parse_map_entry([message_with_map], "string_message_map", ["foo", ""]) + ) + self.assertAllEqual(values_foo, [sub_message.SerializeToString()]) + self.assertAllEqual(values_null, []) + self.assertAllEqual(indices_foo, [0]) + self.assertAllEqual(indices_null, []) + + def test_multiple_messages(self): + message_with_map1 = test_map_pb2.MessageWithMap( + string_string_map={"key1": "foo", "key3": "bar"} + ) + message_with_map2 = test_map_pb2.MessageWithMap() + message_with_map3 = test_map_pb2.MessageWithMap( + string_string_map={"key2": "baz", "key1": "kaz"} + ) + [ + (values_key1, indices_key1), + (values_key2, indices_key2), + (values_key3, indices_key3), + ] = self._parse_map_entry( + [message_with_map1, message_with_map2, message_with_map3], + "string_string_map", + ["key1", "key2", "key3"], + ) + self.assertAllEqual(values_key1, [b"foo", b"kaz"]) + self.assertAllEqual(values_key2, [b"baz"]) + self.assertAllEqual(values_key3, [b"bar"]) + self.assertAllEqual(indices_key1, [0, 2]) + self.assertAllEqual(indices_key2, [2]) + 
self.assertAllEqual(indices_key3, [0]) + + def test_corrupted_message(self): + with self.assertRaises(tf.errors.DataLossError): + self.evaluate( + struct2tensor_ops.parse_proto_map( + tf.constant(["corrupted message"]), + tf.constant([0], dtype=tf.int64), + test_map_pb2.MessageWithMap.DESCRIPTOR.fields_by_name[ + "int32_string_map" + ].message_type, + ["0"], + ) + ) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/path.py b/struct2tensor/path.py index ec3824e..57e07bb 100644 --- a/struct2tensor/path.py +++ b/struct2tensor/path.py @@ -41,257 +41,272 @@ def get_anonymous_field() -> AnonymousId: - """Gets a globally unique anonymous field.""" - # TODO(martinz): Add thread safety here. - global _NEXT_ANONYMOUS_FIELD - result = _NEXT_ANONYMOUS_FIELD - _NEXT_ANONYMOUS_FIELD += 1 - return result + """Gets a globally unique anonymous field.""" + # TODO(martinz): Add thread safety here. + global _NEXT_ANONYMOUS_FIELD + result = _NEXT_ANONYMOUS_FIELD + _NEXT_ANONYMOUS_FIELD += 1 + return result def _compare_step(a: Step, b: Step) -> int: - """Return positive, zero, or negative if a>b, a==b, or ab, a==b, or a b: - return 1 - if a < b: - return -1 - return 0 - - if aint: - return 1 - return -1 - - -class Path: - """A representation of a path in the expression. + """Return positive, zero, or negative if a>b, a==b, or a int: # pytype: disable=signature-mismatch # overriding-return-type-checks - """Lexicographical ordering of paths. - - If one path is a strict prefix of the other, the prefix is less. - Otherwise, the prefix of the longer path is compared to the shorter path. + AnonymousIds are greater than string values. Args: - other: the path to compare to. + a: A step. + b: A step. Returns: - -1, 0, or 1 if this is <,==, or > other + positive, zero, or negative if a>b, a==b, or a 0: - return step_diff - if step_diff < 0: - return step_diff - len_diff = len(self.field_list) - len(other.field_list) - if len_diff > 0: - return 1 - if len_diff < 0: - return -1 - return 0 - - def __eq__(self, other: "Path") -> bool: - return self.__cmp__(other) == 0 + aint = isinstance(a, AnonymousId) + bint = isinstance(b, AnonymousId) - def __ne__(self, other: "Path") -> bool: - return self.__cmp__(other) != 0 + if aint == bint: + if a > b: + return 1 + if a < b: + return -1 + return 0 - def __le__(self, other: "Path") -> bool: - return self.__cmp__(other) <= 0 + if aint: + return 1 + return -1 - def __lt__(self, other: "Path") -> bool: - return self.__cmp__(other) < 0 - def __ge__(self, other: "Path") -> bool: - return self.__cmp__(other) >= 0 - - def __gt__(self, other: "Path") -> bool: - return self.__cmp__(other) > 0 - - def __hash__(self) -> int: - return hash(self.field_list) - - def get_parent(self) -> "Path": - """Get the parent path. +class Path: + """A representation of a path in the expression. - Returns: - The parent path. + Do not implement __nonzero__, __eq__, __ne__, et cetera as these are + implicitly defined by __cmp__ and __len__. - Raises: - ValueError: If this is the root path. 
""" - if not self: - raise ValueError("Tried to find parent of root") - return Path(self.field_list[:-1], - validate_step_format=self._validate_step_format) - - def get_child(self, field_name: Step) -> "Path": - """Get the child path.""" - if ( - isinstance(field_name, str) - and self._validate_step_format - and not is_valid_step(field_name) - ): - raise ValueError("field_name is not valid: " + field_name) - return Path( - self.field_list + (field_name,), - validate_step_format=self._validate_step_format, - ) - def concat(self, other_path: "Path") -> "Path": - return Path( - self.field_list + other_path.field_list, - validate_step_format=self._validate_step_format, - ) + def __init__(self, field_list: Sequence[Step], validate_step_format=True): + """Create a path object. + + Args: + field_list: a list or tuple of fields leading from one node to another. + validate_step_format: If True, validates that steps do not have any + characters that could be ambiguously understood as structure delimiters + (e.g. "."). If False, such characters are allowed and the client is + responsible to ensure do not rely on any auto-coercion of strings to + paths. + + Raises: + ValueError: if any field is not a valid step (see is_valid_step). + """ + for field in field_list: + if ( + isinstance(field, str) + and validate_step_format + and not is_valid_step(field) + ): + raise ValueError('Field "' + field + '" is invalid.') + self.field_list = tuple(field_list) + self._validate_step_format = validate_step_format + + def __cmp__( + self, other: "Path" + ) -> int: # pytype: disable=signature-mismatch # overriding-return-type-checks + """Lexicographical ordering of paths. + + If one path is a strict prefix of the other, the prefix is less. + Otherwise, the prefix of the longer path is compared to the shorter path. + + Args: + other: the path to compare to. + + Returns: + -1, 0, or 1 if this is <,==, or > other + """ + for i in range(min(len(self.field_list), len(other.field_list))): + step_diff = _compare_step(self.field_list[i], other.field_list[i]) + if step_diff > 0: + return step_diff + if step_diff < 0: + return step_diff + len_diff = len(self.field_list) - len(other.field_list) + if len_diff > 0: + return 1 + if len_diff < 0: + return -1 + return 0 + + def __eq__(self, other: "Path") -> bool: + return self.__cmp__(other) == 0 + + def __ne__(self, other: "Path") -> bool: + return self.__cmp__(other) != 0 + + def __le__(self, other: "Path") -> bool: + return self.__cmp__(other) <= 0 + + def __lt__(self, other: "Path") -> bool: + return self.__cmp__(other) < 0 + + def __ge__(self, other: "Path") -> bool: + return self.__cmp__(other) >= 0 + + def __gt__(self, other: "Path") -> bool: + return self.__cmp__(other) > 0 + + def __hash__(self) -> int: + return hash(self.field_list) + + def get_parent(self) -> "Path": + """Get the parent path. + + Returns: + The parent path. + + Raises: + ValueError: If this is the root path. 
+ """ + if not self: + raise ValueError("Tried to find parent of root") + return Path( + self.field_list[:-1], validate_step_format=self._validate_step_format + ) + + def get_child(self, field_name: Step) -> "Path": + """Get the child path.""" + if ( + isinstance(field_name, str) + and self._validate_step_format + and not is_valid_step(field_name) + ): + raise ValueError("field_name is not valid: " + field_name) + return Path( + self.field_list + (field_name,), + validate_step_format=self._validate_step_format, + ) + + def concat(self, other_path: "Path") -> "Path": + return Path( + self.field_list + other_path.field_list, + validate_step_format=self._validate_step_format, + ) + + def prefix(self, ending_index: int) -> "Path": + return Path( + self.field_list[:ending_index], + validate_step_format=self._validate_step_format, + ) + + def suffix(self, starting_index: int) -> "Path": + return Path( + self.field_list[starting_index:], + validate_step_format=self._validate_step_format, + ) + + def __len__(self) -> int: + return len(self.field_list) + + def _get_least_common_ancestor_len(self, other: "Path") -> int: + """Get the length of the LCA path (the longest shared prefix).""" + min_length = min(len(self.field_list), len(other.field_list)) + for i in range(min_length): + if self.field_list[i] != other.field_list[i]: + return i + return min_length + + def get_least_common_ancestor(self, other: "Path") -> "Path": + """Get the least common ancestor, the longest shared prefix.""" + lca_len = self._get_least_common_ancestor_len(other) + return Path( + self.field_list[:lca_len], + validate_step_format=self._validate_step_format, + ) + + def is_ancestor(self, other: "Path") -> bool: + """True if self is ancestor of other (i.e. a prefix).""" + return len(self.field_list) <= len(other.field_list) and self == Path( + other.field_list[: len(self.field_list)] + ) + + def as_proto(self): + """Serialize a path as a proto. + + This fails if there are any anonymous fields. + + Returns: + a Path proto. + """ + result = tf_metadata_path_pb2.Path() + for x in self.field_list: + if isinstance(x, str): + result.step.append(x) + elif isinstance(x, AnonymousId): + raise ValueError("Cannot serialize a path with anonymous fields") + else: + raise ValueError(f"Unexpected path element type: {type(x)}") + return result + + def __repr__(self) -> str: + """Formats Path for pprint.""" + return str(self) + + def __str__(self) -> str: + """Get a string representation of this path. + + Note that if some fields have periods in them, then: + create_path(str(path)) != path + + Returns: + A string representation of the path, using periods. 
+ """ + return ".".join([str(x) for x in self.field_list]) + + def __add__(self, other: Union["Path", str]) -> "Path": + if isinstance(other, str): + other = create_path(other) + return self.concat(other) - def prefix(self, ending_index: int) -> "Path": - return Path( - self.field_list[:ending_index], - validate_step_format=self._validate_step_format, - ) - def suffix(self, starting_index: int) -> "Path": - return Path( - self.field_list[starting_index:], - validate_step_format=self._validate_step_format, - ) - - def __len__(self) -> int: - return len(self.field_list) - - def _get_least_common_ancestor_len(self, other: "Path") -> int: - """Get the length of the LCA path (the longest shared prefix).""" - min_length = min(len(self.field_list), len(other.field_list)) - for i in range(min_length): - if self.field_list[i] != other.field_list[i]: - return i - return min_length - - def get_least_common_ancestor(self, other: "Path") -> "Path": - """Get the least common ancestor, the longest shared prefix.""" - lca_len = self._get_least_common_ancestor_len(other) - return Path( - self.field_list[:lca_len], - validate_step_format=self._validate_step_format, +def is_valid_step(step_str: str) -> bool: + """Return true if step_str is a valid step (see create_path).""" + return ( + re.match( + "(?:" + + _EXTENSION_REGEX + + "|" + + _SIMPLE_STEP_REGEX + + "|" + + _MAP_INDEXING_STEP_REGEX + + ")$", + step_str, + re.VERBOSE, + ) + is not None ) - def is_ancestor(self, other: "Path") -> bool: - """True if self is ancestor of other (i.e. a prefix).""" - return len(self.field_list) <= len(other.field_list) and self == Path( - other.field_list[:len(self.field_list)]) - - def as_proto(self): - """Serialize a path as a proto. - This fails if there are any anonymous fields. - - Returns: - a Path proto. - """ - result = tf_metadata_path_pb2.Path() - for x in self.field_list: - if isinstance(x, str): - result.step.append(x) - elif isinstance(x, AnonymousId): - raise ValueError("Cannot serialize a path with anonymous fields") - else: - raise ValueError(f"Unexpected path element type: {type(x)}") - return result - - def __repr__(self) -> str: - """Formats Path for pprint.""" - return str(self) - - def __str__(self) -> str: - """Get a string representation of this path. +def is_extension(step_str: str) -> bool: + """Return true if step_str is an extension or Any. - Note that if some fields have periods in them, then: - create_path(str(path)) != path + Args: + step_str: the string to evaluate Returns: - A string representation of the path, using periods. + True if step_str is an extension + Raises: + ValueError: if step_str is not a valid step. """ - return ".".join([str(x) for x in self.field_list]) - - def __add__(self, other: Union["Path", str]) -> "Path": - if isinstance(other, str): - other = create_path(other) - return self.concat(other) - - -def is_valid_step(step_str: str) -> bool: - """Return true if step_str is a valid step (see create_path).""" - return re.match( - "(?:" + _EXTENSION_REGEX + "|" + _SIMPLE_STEP_REGEX + "|" + - _MAP_INDEXING_STEP_REGEX + ")$", step_str, re.VERBOSE) is not None - - -def is_extension(step_str: str) -> bool: - """Return true if step_str is an extension or Any. - - Args: - step_str: the string to evaluate - - Returns: - True if step_str is an extension - Raises: - ValueError: if step_str is not a valid step. 
- """ - if not is_valid_step(step_str): - raise ValueError('Not a valid step in a path: "' + step_str + '"') - return step_str[0] == "(" + if not is_valid_step(step_str): + raise ValueError('Not a valid step in a path: "' + step_str + '"') + return step_str[0] == "(" def get_raw_extension_name(step_str: str) -> str: - """Gets the step without the parentheses.""" - if not is_valid_step(step_str): - raise ValueError('Not a valid step in a path: "' + step_str + '"') - if not is_extension(step_str): - raise ValueError('Not an extension: "' + step_str + '"') - return step_str[1:-1] + """Gets the step without the parentheses.""" + if not is_valid_step(step_str): + raise ValueError('Not a valid step in a path: "' + step_str + '"') + if not is_extension(step_str): + raise ValueError('Not an extension: "' + step_str + '"') + return step_str[1:-1] # The purpose of this type is to make it easy to write down paths as literals. @@ -305,143 +320,150 @@ def get_raw_extension_name(step_str: str) -> str: def is_map_indexing_step(step: str) -> bool: - return re.match(_MAP_INDEXING_STEP_REGEX, step, re.VERBOSE) is not None + return re.match(_MAP_INDEXING_STEP_REGEX, step, re.VERBOSE) is not None def parse_map_indexing_step(step: str) -> Tuple[str, str]: - first_bracket = step.find("[") - return step[:first_bracket], step[first_bracket + 1:-1] + first_bracket = step.find("[") + return step[:first_bracket], step[first_bracket + 1 : -1] def create_path(path_source: CoercableToPath) -> Path: - """Create a path from an object. - - The BNF for a path is: - letter := [A-Za-z] - digit := [0-9] - := "_"|"-"| | letter | digit - := + - := "(" ( ".")* ")" - := | - := (( ".") * )? - - TODO(martinz): consider removing dash. This would break YouTube WatchNext. - - Args: - path_source: a string or a Path object. - - Returns: - A Path. - - Raises: - ValueError: if this is not a valid path. - """ - if isinstance(path_source, Path): - return path_source - if path_source and path_source[-1] == ".": - # If we removed this then the period at the end would be ignored, and - # "foo.bar." would become ['foo', 'bar'] - raise ValueError("Path cannot end with .") - result = [] - path_remaining = path_source - # Capture a simple or extension step, then capture the next dot or end. - path_step_separator_re = re.compile( - "(" + _EXTENSION_REGEX + "|" + _SIMPLE_STEP_REGEX + "|" + - _MAP_INDEXING_STEP_REGEX + r""")(\.|$)""", re.VERBOSE) - while path_remaining: - next_match = path_step_separator_re.match(path_remaining) - if next_match: - result.append(next_match.group(1)) - path_remaining = path_remaining[next_match.end():] - else: - raise ValueError("Malformed path: " + path_source) - return Path(result) + """Create a path from an object. + + The BNF for a path is: + letter := [A-Za-z] + digit := [0-9] + := "_"|"-"| | letter | digit + := + + := "(" ( ".")* ")" + := | + := (( ".") * )? + + TODO(martinz): consider removing dash. This would break YouTube WatchNext. + + Args: + path_source: a string or a Path object. + + Returns: + A Path. + + Raises: + ValueError: if this is not a valid path. + """ + if isinstance(path_source, Path): + return path_source + if path_source and path_source[-1] == ".": + # If we removed this then the period at the end would be ignored, and + # "foo.bar." would become ['foo', 'bar'] + raise ValueError("Path cannot end with .") + result = [] + path_remaining = path_source + # Capture a simple or extension step, then capture the next dot or end. 
+ path_step_separator_re = re.compile( + "(" + + _EXTENSION_REGEX + + "|" + + _SIMPLE_STEP_REGEX + + "|" + + _MAP_INDEXING_STEP_REGEX + + r""")(\.|$)""", + re.VERBOSE, + ) + while path_remaining: + next_match = path_step_separator_re.match(path_remaining) + if next_match: + result.append(next_match.group(1)) + path_remaining = path_remaining[next_match.end() :] + else: + raise ValueError("Malformed path: " + path_source) + return Path(result) def from_proto(path_proto: tf_metadata_path_pb2.Path) -> Path: - # Coerce each step to a native string. The steps in the proto are always - # Unicode strings, but the Path class may contain either unicode or bytes - # depending on whether this module is loaded with Python2 or Python3. - return Path([str(step) for step in path_proto.step]) + # Coerce each step to a native string. The steps in the proto are always + # Unicode strings, but the Path class may contain either unicode or bytes + # depending on whether this module is loaded with Python2 or Python3. + return Path([str(step) for step in path_proto.step]) def expand_wildcard_proto_paths( paths: Sequence[List[str]], proto_descriptor: descriptor.Descriptor ) -> List[Path]: - """Expands wildcard paths in the given sequence. - - Each path is represented as a list of str. Supported wildcard patterns are - single wildcard or paths ending with wildcard. - For example, ["*"] or ["FieldA", "FieldB", "*"]. - Partial matching paths are invalid, for example, ["Field*"]. - - Args: - paths: a list of Proto field paths to be expanded. - proto_descriptor: proto descriptor. - - Returns: - A list of Path objects. - - Raises: - ValueError: if any expanded path is also provided by user. - """ - if not paths: - return [] - - paths_with_wildcards = [] - result = set() - for path in paths: - if path == ["*"]: - return _get_all_subfields(proto_descriptor) - # path prefixes ending with "*" will be expanded - if path[-1] == "*" and len(path) > 1: - paths_with_wildcards.append(path[:-1]) - else: - # partial matching paths (e.g., ["feature*"]) will raise an error - # when converted to Path objects. - result.add(Path(path)) - - # expand path prefixes ending with wildcard - if paths_with_wildcards: - for pattern in paths_with_wildcards: - expanded_paths = _proto_glob(pattern, proto_descriptor) - for p in expanded_paths: - if p in result: - raise ValueError(f"Duplicate path {p} is detected.") - result.add(p) - - return sorted(result) + """Expands wildcard paths in the given sequence. + + Each path is represented as a list of str. Supported wildcard patterns are + single wildcard or paths ending with wildcard. + For example, ["*"] or ["FieldA", "FieldB", "*"]. + Partial matching paths are invalid, for example, ["Field*"]. + + Args: + paths: a list of Proto field paths to be expanded. + proto_descriptor: proto descriptor. + + Returns: + A list of Path objects. + + Raises: + ValueError: if any expanded path is also provided by user. + """ + if not paths: + return [] + + paths_with_wildcards = [] + result = set() + for path in paths: + if path == ["*"]: + return _get_all_subfields(proto_descriptor) + # path prefixes ending with "*" will be expanded + if path[-1] == "*" and len(path) > 1: + paths_with_wildcards.append(path[:-1]) + else: + # partial matching paths (e.g., ["feature*"]) will raise an error + # when converted to Path objects. 
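
# A hedged sketch of the wildcard semantics described in the docstring above,
# assuming a hypothetical descriptor `desc` whose message has exactly the leaf
# fields FieldA.FieldB.x and FieldA.FieldB.y:
#
#   expand_wildcard_proto_paths([["FieldA", "FieldB", "*"]], desc)
#     == [Path(["FieldA", "FieldB", "x"]), Path(["FieldA", "FieldB", "y"])]
#   expand_wildcard_proto_paths([["Field*"]], desc)
#     -> raises ValueError (partial-step wildcards like "Field*" are rejected
#        when the step is validated by Path)
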
+ result.add(Path(path)) + + # expand path prefixes ending with wildcard + if paths_with_wildcards: + for pattern in paths_with_wildcards: + expanded_paths = _proto_glob(pattern, proto_descriptor) + for p in expanded_paths: + if p in result: + raise ValueError(f"Duplicate path {p} is detected.") + result.add(p) + + return sorted(result) def _get_all_subfields(proto_descriptor: descriptor.Descriptor) -> List[Path]: - """Extracts all subfield paths for the given descriptor.""" - result = [] - messages = set() - - def _recurse(d, prefix): - messages.add(d) - for f in d.fields: - if f.message_type in messages: - raise ValueError("Proto recursion is detected.") - prefix.append(f.name) - if f.message_type: - _recurse(f.message_type, prefix) - else: - result.append(Path(prefix)) - prefix.pop() - - _recurse(proto_descriptor, []) - return result + """Extracts all subfield paths for the given descriptor.""" + result = [] + messages = set() + + def _recurse(d, prefix): + messages.add(d) + for f in d.fields: + if f.message_type in messages: + raise ValueError("Proto recursion is detected.") + prefix.append(f.name) + if f.message_type: + _recurse(f.message_type, prefix) + else: + result.append(Path(prefix)) + prefix.pop() + + _recurse(proto_descriptor, []) + return result def _proto_glob( pattern: List[str], proto_descriptor: descriptor.Descriptor ) -> List[Path]: - """Returns proto paths matching the given prefix pattern.""" - parent = Path(pattern) - for field_name in parent.field_list: - if field_name not in proto_descriptor.fields_by_name: - raise ValueError(f"Field name {field_name} does not exist.") - proto_descriptor = proto_descriptor.fields_by_name[field_name].message_type - expanded_leaves = _get_all_subfields(proto_descriptor) - return [parent + child for child in expanded_leaves] + """Returns proto paths matching the given prefix pattern.""" + parent = Path(pattern) + for field_name in parent.field_list: + if field_name not in proto_descriptor.fields_by_name: + raise ValueError(f"Field name {field_name} does not exist.") + proto_descriptor = proto_descriptor.fields_by_name[field_name].message_type + expanded_leaves = _get_all_subfields(proto_descriptor) + return [parent + child for child in expanded_leaves] diff --git a/struct2tensor/path_test.py b/struct2tensor/path_test.py index dcf7423..2a0ed41 100644 --- a/struct2tensor/path_test.py +++ b/struct2tensor/path_test.py @@ -19,11 +19,11 @@ from absl.testing import absltest, parameterized from struct2tensor.path import ( - Path, - create_path, - expand_wildcard_proto_paths, - from_proto, - parse_map_indexing_step, + Path, + create_path, + expand_wildcard_proto_paths, + from_proto, + parse_map_indexing_step, ) from struct2tensor.test import test_pb2 @@ -86,8 +86,10 @@ ["event", "action", "number_of_views"], ["event", "user_info", "gender"], ], - "expected": [Path(["event", "action", "number_of_views"]), - Path(["event", "user_info", "gender"])], + "expected": [ + Path(["event", "action", "number_of_views"]), + Path(["event", "user_info", "gender"]), + ], }, { "testcase_name": "test_oneof", @@ -144,263 +146,270 @@ class PathTest(parameterized.TestCase): - - def test_get_child(self): - original_path = create_path("foo.bar") - child_path = original_path.get_child("baz") - self.assertEqual(str(child_path), "foo.bar.baz") - - def test_get_child_path_with_complex_steps(self): - # When provided as a string, "foo.bar.bat" is coerced into a path where - # the dots are interpreted as structure delimiters. 
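
# The contrast this test exercises, in miniature (a hedged sketch, not part of
# the patch): with the default validation, dots in a string act as structure
# delimiters, while with validate_step_format=False a dotted string can be
# kept as one opaque step.
#
#   create_path("foo.bar.bat").field_list
#       == ("foo", "bar", "bat")
#   Path(["foo.bar.bat"], validate_step_format=False).field_list
#       == ("foo.bar.bat",)
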
- original_path = create_path("foo.bar.bat") - child_path = original_path.get_child("baz") - prefix = original_path.prefix(1) - suffix = original_path.suffix(1) - self.assertCountEqual( - child_path.concat(prefix).concat(suffix).field_list, - ["foo", "bar", "bat", "baz", "foo", "bar", "bat"], - ) - self.assertCountEqual(child_path.field_list, ["foo", "bar", "bat", "baz"]) - - # A path can be created where such a string is a single step, by disabling - # the default validation. - original_path = Path(["foo.bar.bat"], validate_step_format=False) - child_path = original_path.get_child("baz") - prefix = original_path.prefix(1) - suffix = original_path.suffix(1) - self.assertCountEqual( - child_path.concat(prefix).concat(suffix).field_list, - ["foo.bar.bat", "baz", "foo.bar.bat"], - ) - self.assertCountEqual(child_path.field_list, ["foo.bar.bat", "baz"]) - - another_path = Path(["foo/bar/bat"], validate_step_format=False) - ancestor = another_path.get_least_common_ancestor(original_path) - self.assertEmpty(ancestor.field_list) - - def test_get_child_root(self): - original_path = create_path("") - child_path = original_path.get_child("baz") - self.assertEqual(str(child_path), "baz") - - def test_root_path(self): - original_path = create_path("") - self.assertEmpty(original_path.field_list) - - def test_eq(self): - self.assertEqual(create_path("foo.bar"), create_path("foo.bar")) - self.assertNotEqual(create_path("foo.bar"), create_path("foo.baz")) - self.assertNotEqual(create_path("foo"), create_path("foo.baz")) - - def test_cmp(self): - self.assertGreater(Path([1]), Path("foo")) - self.assertLess(Path("foo"), Path([1])) - self.assertGreater(Path([1]), Path([0])) - self.assertGreater(create_path("foo.baz"), create_path("foo")) - self.assertGreater(create_path("foo.baz"), create_path("foo.bar")) - self.assertLess(create_path("foo"), create_path("foo.bar")) - self.assertLess(create_path("foo.bar"), create_path("foo.baz")) - self.assertEqual(create_path("foo.baz"), create_path("foo.baz")) - - def test_extension_begin(self): - original_path = create_path("(foo.bar.baz.ext).bee") - self.assertEqual("(foo.bar.baz.ext)", str(original_path.get_parent())) - - def test_extension_end(self): - original_path = create_path("bee.(foo.bar.baz.ext)") - self.assertEqual("(foo.bar.baz.ext)", original_path.field_list[1]) - - def test_extension_middle(self): - original_path = create_path("bee.(foo.bar.baz.ext).boo") - self.assertEqual("(foo.bar.baz.ext)", original_path.field_list[1]) - self.assertEqual("boo", original_path.field_list[2]) - - def test_multiple_extensions(self): - original_path = create_path("bee.(foo.bar.ext).boo.(foo.bar.ext2).buu") - self.assertEqual("(foo.bar.ext)", original_path.field_list[1]) - self.assertEqual("(foo.bar.ext2)", original_path.field_list[3]) - - def test_hash(self): - self.assertEqual(hash(create_path("foo.bar")), hash(create_path("foo.bar"))) - - def test_nonzero(self): - self.assertTrue(bool(create_path("foo"))) - self.assertFalse(bool(create_path(""))) - - def test_get_parent(self): - self.assertEqual(create_path("foo.bar").get_parent(), create_path("foo")) - - def test_concat(self): - self.assertEqual( - create_path("foo.bar").concat(create_path("baz.bax")), - create_path("foo.bar.baz.bax")) - - def test_get_least_common_ancestor(self): - self.assertEqual( - create_path("foo.bar.bax").get_least_common_ancestor( - create_path("foo.bar.baz")), create_path("foo.bar")) - - def test_get_least_common_ancestor_root(self): - self.assertEqual( - 
create_path("foo.bar.bax").get_least_common_ancestor( - create_path("different.bar.baz")), create_path("")) - - def test_is_ancestor(self): - ancestor = create_path("foo.bar") - descendant = create_path("foo.bar.baz") - not_ancestor = create_path("fuzz") - self.assertEqual(ancestor.is_ancestor(descendant), True) - self.assertEqual(ancestor.is_ancestor(ancestor), True) - self.assertEqual(not_ancestor.is_ancestor(descendant), False) - self.assertEqual(descendant.is_ancestor(ancestor), False) - - def test_dash_in_name(self): - simple_step = create_path("foo-bar.baz") - self.assertEqual(str(simple_step), "foo-bar.baz") - first_extension = create_path("(foo-bar.bak).baz") - self.assertEqual(str(first_extension), "(foo-bar.bak).baz") - last_extension = create_path("(foo.bar-bak).baz") - self.assertEqual(str(last_extension), "(foo.bar-bak).baz") - - def test_suffix(self): - original = create_path("foo.bar.baz") - self.assertEqual(str(original.suffix(0)), "foo.bar.baz") - self.assertEqual(str(original.suffix(1)), "bar.baz") - self.assertEqual(str(original.suffix(2)), "baz") - self.assertEqual(str(original.suffix(3)), "") - self.assertEqual(str(original.suffix(-1)), "baz") - self.assertEqual(str(original.suffix(-2)), "bar.baz") - - def test_prefix(self): - original = create_path("foo.bar.baz") - self.assertEqual(str(original.prefix(0)), "") - self.assertEqual(str(original.prefix(1)), "foo") - self.assertEqual(str(original.prefix(2)), "foo.bar") - self.assertEqual(str(original.prefix(3)), "foo.bar.baz") - self.assertEqual(str(original.prefix(-1)), "foo.bar") - self.assertEqual(str(original.prefix(-2)), "foo") - - def test_len(self): - self.assertLen(create_path(""), 0) # pylint: disable=g-generic-assert - self.assertLen(create_path("foo"), 1) - self.assertLen(create_path("foo.bar"), 2) - self.assertLen(create_path("foo.bar.baz"), 3) - - def test_str(self): - self.assertEqual(str(create_path("foo.bar.bax")), "foo.bar.bax") - - def test_empty_step(self): - with self.assertRaises(ValueError): - create_path("foo..bax") - - def test_non_alpha(self): - with self.assertRaises(ValueError): - create_path("foo.+.bax") - - def test_empty_extension_step(self): - with self.assertRaises(ValueError): - create_path("foo.(.baz).bax") - - def test_get_child_bad_field(self): - with self.assertRaises(ValueError): - create_path("foo.bax").get_child("foo.bar") - - def test_empty_extension(self): - with self.assertRaises(ValueError): - create_path("foo.().bax") - - def test_mismatched_open_paren(self): - with self.assertRaises(ValueError): - create_path("foo.(bar.bax") - - def test_mismatched_closed_paren(self): - with self.assertRaises(ValueError): - create_path("foo.bar).bax") - - def test_matched_middle_paren(self): - with self.assertRaises(ValueError): - create_path("foo.b(ar).bax") - - def test_valid_map_indexing_step(self): - self.assertSequenceEqual( - ["my_map[some_key]", "some_value"], - create_path("my_map[some_key].some_value").field_list) - self.assertSequenceEqual( - ["my_map[]", "some_value"], - create_path("my_map[].some_value").field_list) - self.assertSequenceEqual( - ["my_map[key.1]", "some_value"], - create_path("my_map[key.1].some_value").field_list) - self.assertSequenceEqual( - ["my_map[(key)]", "some_value"], - create_path("my_map[(key)].some_value").field_list) - self.assertSequenceEqual( - ["my_map[[]", "some_value"], - create_path("my_map[[].some_value").field_list) - - def test_invalid_map_indexing_step(self): - with self.assertRaises(ValueError): - create_path("[mymap[s].some_value") - with 
self.assertRaises(ValueError): - create_path("mymap[]].some_value") - - def test_parse_map_indexing_step(self): - map_field_name, map_key = parse_map_indexing_step("my_map[some_key]") - self.assertEqual("my_map", map_field_name) - self.assertEqual("some_key", map_key) - - map_field_name, map_key = parse_map_indexing_step("my_map[]") - self.assertEqual("my_map", map_field_name) - self.assertEqual("", map_key) - - map_field_name, map_key = parse_map_indexing_step("my_map[[.]") - self.assertEqual("my_map", map_field_name) - self.assertEqual("[.", map_key) - - def test_as_proto(self): - p = create_path("foo.(bar.e).baz") - path_proto = p.as_proto() - self.assertSequenceEqual(path_proto.step, ["foo", "(bar.e)", "baz"]) - - def test_proto_roundtrip(self): - p = create_path("foo.(bar.e).baz") - path_proto = p.as_proto() - p_from_proto = from_proto(path_proto) - self.assertEqual(p, p_from_proto) - self.assertEqual(path_proto, p_from_proto.as_proto()) - - def test_pprint(self): - p = create_path("foo.bar.baz") - self.assertEqual(pprint.pformat(p), "foo.bar.baz") - - def test_add(self): - # Test add two paths. - self.assertEqual( - create_path("foo.bar") + create_path("baz.bax"), - create_path("foo.bar.baz.bax"), - ) - - # Test add a path with a string. - self.assertEqual( - create_path("foo.bar") + "baz.bax", create_path("foo.bar.baz.bax") - ) - - @parameterized.named_parameters(*_FORMAT_TEST_CASES) - def test_expand_paths_ending_in_wildcard( - self, descriptor, input_paths, expected - ): - self.assertCountEqual( - expand_wildcard_proto_paths(input_paths, descriptor), expected - ) - - @parameterized.named_parameters(*_FORMAT_TEST_CASES_FAILED) - def test_expand_paths_ending_in_wildcard_recursion_proto( - self, descriptor, input_paths, expected_error_message - ): - with self.assertRaisesRegex(ValueError, expected_error_message): - expand_wildcard_proto_paths(input_paths, descriptor) + def test_get_child(self): + original_path = create_path("foo.bar") + child_path = original_path.get_child("baz") + self.assertEqual(str(child_path), "foo.bar.baz") + + def test_get_child_path_with_complex_steps(self): + # When provided as a string, "foo.bar.bat" is coerced into a path where + # the dots are interpreted as structure delimiters. + original_path = create_path("foo.bar.bat") + child_path = original_path.get_child("baz") + prefix = original_path.prefix(1) + suffix = original_path.suffix(1) + self.assertCountEqual( + child_path.concat(prefix).concat(suffix).field_list, + ["foo", "bar", "bat", "baz", "foo", "bar", "bat"], + ) + self.assertCountEqual(child_path.field_list, ["foo", "bar", "bat", "baz"]) + + # A path can be created where such a string is a single step, by disabling + # the default validation. 
+ original_path = Path(["foo.bar.bat"], validate_step_format=False) + child_path = original_path.get_child("baz") + prefix = original_path.prefix(1) + suffix = original_path.suffix(1) + self.assertCountEqual( + child_path.concat(prefix).concat(suffix).field_list, + ["foo.bar.bat", "baz", "foo.bar.bat"], + ) + self.assertCountEqual(child_path.field_list, ["foo.bar.bat", "baz"]) + + another_path = Path(["foo/bar/bat"], validate_step_format=False) + ancestor = another_path.get_least_common_ancestor(original_path) + self.assertEmpty(ancestor.field_list) + + def test_get_child_root(self): + original_path = create_path("") + child_path = original_path.get_child("baz") + self.assertEqual(str(child_path), "baz") + + def test_root_path(self): + original_path = create_path("") + self.assertEmpty(original_path.field_list) + + def test_eq(self): + self.assertEqual(create_path("foo.bar"), create_path("foo.bar")) + self.assertNotEqual(create_path("foo.bar"), create_path("foo.baz")) + self.assertNotEqual(create_path("foo"), create_path("foo.baz")) + + def test_cmp(self): + self.assertGreater(Path([1]), Path("foo")) + self.assertLess(Path("foo"), Path([1])) + self.assertGreater(Path([1]), Path([0])) + self.assertGreater(create_path("foo.baz"), create_path("foo")) + self.assertGreater(create_path("foo.baz"), create_path("foo.bar")) + self.assertLess(create_path("foo"), create_path("foo.bar")) + self.assertLess(create_path("foo.bar"), create_path("foo.baz")) + self.assertEqual(create_path("foo.baz"), create_path("foo.baz")) + + def test_extension_begin(self): + original_path = create_path("(foo.bar.baz.ext).bee") + self.assertEqual("(foo.bar.baz.ext)", str(original_path.get_parent())) + + def test_extension_end(self): + original_path = create_path("bee.(foo.bar.baz.ext)") + self.assertEqual("(foo.bar.baz.ext)", original_path.field_list[1]) + + def test_extension_middle(self): + original_path = create_path("bee.(foo.bar.baz.ext).boo") + self.assertEqual("(foo.bar.baz.ext)", original_path.field_list[1]) + self.assertEqual("boo", original_path.field_list[2]) + + def test_multiple_extensions(self): + original_path = create_path("bee.(foo.bar.ext).boo.(foo.bar.ext2).buu") + self.assertEqual("(foo.bar.ext)", original_path.field_list[1]) + self.assertEqual("(foo.bar.ext2)", original_path.field_list[3]) + + def test_hash(self): + self.assertEqual(hash(create_path("foo.bar")), hash(create_path("foo.bar"))) + + def test_nonzero(self): + self.assertTrue(bool(create_path("foo"))) + self.assertFalse(bool(create_path(""))) + + def test_get_parent(self): + self.assertEqual(create_path("foo.bar").get_parent(), create_path("foo")) + + def test_concat(self): + self.assertEqual( + create_path("foo.bar").concat(create_path("baz.bax")), + create_path("foo.bar.baz.bax"), + ) + + def test_get_least_common_ancestor(self): + self.assertEqual( + create_path("foo.bar.bax").get_least_common_ancestor( + create_path("foo.bar.baz") + ), + create_path("foo.bar"), + ) + + def test_get_least_common_ancestor_root(self): + self.assertEqual( + create_path("foo.bar.bax").get_least_common_ancestor( + create_path("different.bar.baz") + ), + create_path(""), + ) + + def test_is_ancestor(self): + ancestor = create_path("foo.bar") + descendant = create_path("foo.bar.baz") + not_ancestor = create_path("fuzz") + self.assertEqual(ancestor.is_ancestor(descendant), True) + self.assertEqual(ancestor.is_ancestor(ancestor), True) + self.assertEqual(not_ancestor.is_ancestor(descendant), False) + self.assertEqual(descendant.is_ancestor(ancestor), False) + + 
def test_dash_in_name(self): + simple_step = create_path("foo-bar.baz") + self.assertEqual(str(simple_step), "foo-bar.baz") + first_extension = create_path("(foo-bar.bak).baz") + self.assertEqual(str(first_extension), "(foo-bar.bak).baz") + last_extension = create_path("(foo.bar-bak).baz") + self.assertEqual(str(last_extension), "(foo.bar-bak).baz") + + def test_suffix(self): + original = create_path("foo.bar.baz") + self.assertEqual(str(original.suffix(0)), "foo.bar.baz") + self.assertEqual(str(original.suffix(1)), "bar.baz") + self.assertEqual(str(original.suffix(2)), "baz") + self.assertEqual(str(original.suffix(3)), "") + self.assertEqual(str(original.suffix(-1)), "baz") + self.assertEqual(str(original.suffix(-2)), "bar.baz") + + def test_prefix(self): + original = create_path("foo.bar.baz") + self.assertEqual(str(original.prefix(0)), "") + self.assertEqual(str(original.prefix(1)), "foo") + self.assertEqual(str(original.prefix(2)), "foo.bar") + self.assertEqual(str(original.prefix(3)), "foo.bar.baz") + self.assertEqual(str(original.prefix(-1)), "foo.bar") + self.assertEqual(str(original.prefix(-2)), "foo") + + def test_len(self): + self.assertLen(create_path(""), 0) # pylint: disable=g-generic-assert + self.assertLen(create_path("foo"), 1) + self.assertLen(create_path("foo.bar"), 2) + self.assertLen(create_path("foo.bar.baz"), 3) + + def test_str(self): + self.assertEqual(str(create_path("foo.bar.bax")), "foo.bar.bax") + + def test_empty_step(self): + with self.assertRaises(ValueError): + create_path("foo..bax") + + def test_non_alpha(self): + with self.assertRaises(ValueError): + create_path("foo.+.bax") + + def test_empty_extension_step(self): + with self.assertRaises(ValueError): + create_path("foo.(.baz).bax") + + def test_get_child_bad_field(self): + with self.assertRaises(ValueError): + create_path("foo.bax").get_child("foo.bar") + + def test_empty_extension(self): + with self.assertRaises(ValueError): + create_path("foo.().bax") + + def test_mismatched_open_paren(self): + with self.assertRaises(ValueError): + create_path("foo.(bar.bax") + + def test_mismatched_closed_paren(self): + with self.assertRaises(ValueError): + create_path("foo.bar).bax") + + def test_matched_middle_paren(self): + with self.assertRaises(ValueError): + create_path("foo.b(ar).bax") + + def test_valid_map_indexing_step(self): + self.assertSequenceEqual( + ["my_map[some_key]", "some_value"], + create_path("my_map[some_key].some_value").field_list, + ) + self.assertSequenceEqual( + ["my_map[]", "some_value"], create_path("my_map[].some_value").field_list + ) + self.assertSequenceEqual( + ["my_map[key.1]", "some_value"], + create_path("my_map[key.1].some_value").field_list, + ) + self.assertSequenceEqual( + ["my_map[(key)]", "some_value"], + create_path("my_map[(key)].some_value").field_list, + ) + self.assertSequenceEqual( + ["my_map[[]", "some_value"], create_path("my_map[[].some_value").field_list + ) + + def test_invalid_map_indexing_step(self): + with self.assertRaises(ValueError): + create_path("[mymap[s].some_value") + with self.assertRaises(ValueError): + create_path("mymap[]].some_value") + + def test_parse_map_indexing_step(self): + map_field_name, map_key = parse_map_indexing_step("my_map[some_key]") + self.assertEqual("my_map", map_field_name) + self.assertEqual("some_key", map_key) + + map_field_name, map_key = parse_map_indexing_step("my_map[]") + self.assertEqual("my_map", map_field_name) + self.assertEqual("", map_key) + + map_field_name, map_key = parse_map_indexing_step("my_map[[.]") + 
self.assertEqual("my_map", map_field_name) + self.assertEqual("[.", map_key) + + def test_as_proto(self): + p = create_path("foo.(bar.e).baz") + path_proto = p.as_proto() + self.assertSequenceEqual(path_proto.step, ["foo", "(bar.e)", "baz"]) + + def test_proto_roundtrip(self): + p = create_path("foo.(bar.e).baz") + path_proto = p.as_proto() + p_from_proto = from_proto(path_proto) + self.assertEqual(p, p_from_proto) + self.assertEqual(path_proto, p_from_proto.as_proto()) + + def test_pprint(self): + p = create_path("foo.bar.baz") + self.assertEqual(pprint.pformat(p), "foo.bar.baz") + + def test_add(self): + # Test add two paths. + self.assertEqual( + create_path("foo.bar") + create_path("baz.bax"), + create_path("foo.bar.baz.bax"), + ) + + # Test add a path with a string. + self.assertEqual( + create_path("foo.bar") + "baz.bax", create_path("foo.bar.baz.bax") + ) + + @parameterized.named_parameters(*_FORMAT_TEST_CASES) + def test_expand_paths_ending_in_wildcard(self, descriptor, input_paths, expected): + self.assertCountEqual( + expand_wildcard_proto_paths(input_paths, descriptor), expected + ) + + @parameterized.named_parameters(*_FORMAT_TEST_CASES_FAILED) + def test_expand_paths_ending_in_wildcard_recursion_proto( + self, descriptor, input_paths, expected_error_message + ): + with self.assertRaisesRegex(ValueError, expected_error_message): + expand_wildcard_proto_paths(input_paths, descriptor) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/struct2tensor/prensor.py b/struct2tensor/prensor.py index caa166e..b4c54eb 100644 --- a/struct2tensor/prensor.py +++ b/struct2tensor/prensor.py @@ -27,7 +27,7 @@ import tensorflow as tf from tensorflow.python.framework import ( - composite_tensor, # pylint: disable=g-direct-tensorflow-import + composite_tensor, # pylint: disable=g-direct-tensorflow-import ) from struct2tensor import calculate_options, path @@ -38,827 +38,854 @@ # ChildNodeTensor, and RootNodeTensor, allowing expression.py to depend upon # node.py. class RootNodeTensor: - """The value of the root.""" + """The value of the root.""" - __slots__ = ["_size"] + __slots__ = ["_size"] - def __init__(self, size: tf.Tensor): - """Creates a root node. + def __init__(self, size: tf.Tensor): + """Creates a root node. - Args: - size: A scalar int64 tensor saying how many root objects there are. - """ - self._size = size + Args: + size: A scalar int64 tensor saying how many root objects there are. + """ + self._size = size - @property - def size(self): - return self._size + @property + def size(self): + return self._size - @property - def is_repeated(self): - return True + @property + def is_repeated(self): + return True - def get_positional_index(self) -> tf.Tensor: - """Gets the positional index for this RootNodeTensor. + def get_positional_index(self) -> tf.Tensor: + """Gets the positional index for this RootNodeTensor. - The positional index relative to the node's parent, and thus is always - monotonically increasing at step size 1 for a RootNodeTensor. + The positional index relative to the node's parent, and thus is always + monotonically increasing at step size 1 for a RootNodeTensor. - Returns: - A tensor of positional indices. - """ - return tf.range(self.size) + Returns: + A tensor of positional indices. 
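+            For example, a root of size 3 yields [0, 1, 2].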
+ """ + return tf.range(self.size) - def __str__(self): - return "RootNodeTensor" + def __str__(self): + return "RootNodeTensor" class ChildNodeTensor: - """The value of an intermediate node.""" + """The value of an intermediate node.""" - __slots__ = ["_parent_index", "_is_repeated", "_index_to_value"] + __slots__ = ["_parent_index", "_is_repeated", "_index_to_value"] - def __init__(self, - parent_index: tf.Tensor, - is_repeated: bool, - index_to_value: Optional[tf.Tensor] = None): - """Creates a child node. + def __init__( + self, + parent_index: tf.Tensor, + is_repeated: bool, + index_to_value: Optional[tf.Tensor] = None, + ): + """Creates a child node. - Args: - parent_index: a 1-D int64 tensor where parent_index[i] represents the - parent index of the ith child. - is_repeated: a bool indicating if there can be more than one child per - parent. - index_to_value: a 1-D int64 tensor where index_to_value[i] represents the - `value` of the ith child. Where `value` is a subtree. - """ - self._parent_index = parent_index - self._is_repeated = is_repeated - self._index_to_value = index_to_value + Args: + parent_index: a 1-D int64 tensor where parent_index[i] represents the + parent index of the ith child. + is_repeated: a bool indicating if there can be more than one child per + parent. + index_to_value: a 1-D int64 tensor where index_to_value[i] represents the + `value` of the ith child. Where `value` is a subtree. + """ + self._parent_index = parent_index + self._is_repeated = is_repeated + self._index_to_value = index_to_value - @property - def size(self): - """Returns the size, as if this was the root prensor. + @property + def size(self): + """Returns the size, as if this was the root prensor. - Returns: - A scalar int64 tensor. - """ - return tf.size(self._parent_index, out_type=tf.int64) + Returns: + A scalar int64 tensor. + """ + return tf.size(self._parent_index, out_type=tf.int64) - @property - def parent_index(self): - return self._parent_index + @property + def parent_index(self): + return self._parent_index - @property - def is_repeated(self): - return self._is_repeated + @property + def is_repeated(self): + return self._is_repeated - @property - def index_to_value(self): - return self._index_to_value + @property + def index_to_value(self): + return self._index_to_value - # LINT.IfChange(child_node_tensor) - def get_positional_index(self) -> tf.Tensor: - """Gets the positional index for this ChildNodeTensor. + # LINT.IfChange(child_node_tensor) + def get_positional_index(self) -> tf.Tensor: + """Gets the positional index for this ChildNodeTensor. - The positional index tells us which index of the parent an element is. + The positional index tells us which index of the parent an element is. - For example, with the following parent indices: [0, 0, 2] - we would have positional index: - [ - 0, # The 0th element of the 0th parent. - 1, # The 1st element of the 0th parent. - 0 # The 0th element of the 2nd parent. - ]. + For example, with the following parent indices: [0, 0, 2] + we would have positional index: + [ + 0, # The 0th element of the 0th parent. + 1, # The 1st element of the 0th parent. + 0 # The 0th element of the 2nd parent. + ]. - For more information, view ops/run_length_before_op.cc + For more information, view ops/run_length_before_op.cc - This is the same for Leaf NodeTensors. + This is the same for Leaf NodeTensors. - Returns: - A tensor of positional indices. 
- """ - return struct2tensor_ops.run_length_before(self.parent_index) - # LINT.ThenChange(:leaf_node_tensor) + Returns: + A tensor of positional indices. + """ + return struct2tensor_ops.run_length_before(self.parent_index) + + # LINT.ThenChange(:leaf_node_tensor) - def __str__(self): - cardinality = "repeated" if self.is_repeated else "optional" - return f"{cardinality} ChildNodeTensor" + def __str__(self): + cardinality = "repeated" if self.is_repeated else "optional" + return f"{cardinality} ChildNodeTensor" class LeafNodeTensor: - """The value of a leaf node.""" + """The value of a leaf node.""" - __slots__ = ["_parent_index", "_values", "_is_repeated"] + __slots__ = ["_parent_index", "_values", "_is_repeated"] - def __init__(self, parent_index: tf.Tensor, values: tf.Tensor, - is_repeated: bool): - """Creates a LeafNodeTensor. + def __init__(self, parent_index: tf.Tensor, values: tf.Tensor, is_repeated: bool): + """Creates a LeafNodeTensor. - Args: - parent_index: a 1-D int64 tensor where parent_index[i] represents the - parent index of values[i] - values: a 1-D tensor of equal length to parent_index. - is_repeated: a bool indicating if there can be more than one child per - parent. - """ - self._parent_index = parent_index - self._values = values - self._is_repeated = is_repeated + Args: + parent_index: a 1-D int64 tensor where parent_index[i] represents the + parent index of values[i] + values: a 1-D tensor of equal length to parent_index. + is_repeated: a bool indicating if there can be more than one child per + parent. + """ + self._parent_index = parent_index + self._values = values + self._is_repeated = is_repeated - @property - def parent_index(self): - return self._parent_index + @property + def parent_index(self): + return self._parent_index - @property - def is_repeated(self): - return self._is_repeated + @property + def is_repeated(self): + return self._is_repeated - @property - def values(self): - return self._values + @property + def values(self): + return self._values - # LINT.IfChange(leaf_node_tensor) - def get_positional_index(self) -> tf.Tensor: - """Gets the positional index for this LeafNodeTensor. + # LINT.IfChange(leaf_node_tensor) + def get_positional_index(self) -> tf.Tensor: + """Gets the positional index for this LeafNodeTensor. - The positional index tells us which index of the parent an element is. + The positional index tells us which index of the parent an element is. - For example, with the following parent indices: [0, 0, 2] - we would have positional index: - [ - 0, # The 0th element of the 0th parent. - 1, # The 1st element of the 0th parent. - 0 # The 0th element of the 2nd parent. - ]. + For example, with the following parent indices: [0, 0, 2] + we would have positional index: + [ + 0, # The 0th element of the 0th parent. + 1, # The 1st element of the 0th parent. + 0 # The 0th element of the 2nd parent. + ]. - For more information, view ops/run_length_before_op.cc + For more information, view ops/run_length_before_op.cc - This is the same for Child NodeTensors. + This is the same for Child NodeTensors. - Returns: - A tensor of positional indices. - """ - return struct2tensor_ops.run_length_before(self.parent_index) - # LINT.ThenChange(:child_node_tensor) + Returns: + A tensor of positional indices. 
+ """ + return struct2tensor_ops.run_length_before(self.parent_index) - def __str__(self): - return "{} {}".format("repeated" if self.is_repeated else "optional", - str(self.values.dtype)) + # LINT.ThenChange(:child_node_tensor) + + def __str__(self): + return "{} {}".format( + "repeated" if self.is_repeated else "optional", str(self.values.dtype) + ) def create_required_leaf_node(values: tf.Tensor) -> LeafNodeTensor: - """Create a required leaf node.""" - return LeafNodeTensor( - tf.range(tf.size(values, out_type=tf.int64)), values, False) + """Create a required leaf node.""" + return LeafNodeTensor(tf.range(tf.size(values, out_type=tf.int64)), values, False) NodeTensor = Union[LeafNodeTensor, ChildNodeTensor, RootNodeTensor] # pylint: disable=invalid-name class _PrensorTypeSpec(tf.TypeSpec): - """TypeSpec for Prensor.""" - - class _NodeType(enum.IntEnum): - ROOT = 1 - CHILD = 2 - LEAF = 3 - - __slots__ = [ - "_is_repeated", "_node_type", "_value_dtype", "_children_specs"] - - def __init__(self, is_repeated: Optional[bool], node_type: _NodeType, - value_dtype: Optional[tf.DType], - children_specs: List[Tuple[path.Step, "_PrensorTypeSpec"]]): - self._is_repeated = is_repeated - self._node_type = node_type - self._value_dtype = value_dtype - self._children_specs = children_specs - - @property - def value_type(self): - return Prensor - - def _append_to_components(self, - value: "Prensor", - components: List[tf.Tensor]): - node = value.node - if self._node_type == self._NodeType.ROOT: - assert isinstance(node, RootNodeTensor) - components.append(node.size) - elif self._node_type == self._NodeType.CHILD: - assert isinstance(node, ChildNodeTensor) - components.append(node.parent_index) - else: - assert isinstance(node, LeafNodeTensor) - components.append(node.parent_index) - components.append(node.values) - - for (_, child_spec), child in zip( - self._children_specs, value.get_children().values()): - child_spec._append_to_components( # pylint: disable=protected-access - child, components) - - def _to_components(self, value: "Prensor") -> List[tf.Tensor]: - """Encodes a Prensor into a tuple of lists of Tensors. - - Args: - value: A Prensor. - - Returns: - A list of Tensors which are what each NodeTensor in a Prensor consists of. - They are grouped by their owning NodeTensors in pre-order (the children - of one node are ordered the same as the _children OrderedDict of that - node). 
- """ - result = [] - self._append_to_components(value, result) - return result - - def _from_component_iter( - self, - component_iter: Iterator[tf.Tensor]) -> "Prensor": - if self._node_type == self._NodeType.ROOT: - node = RootNodeTensor(next(component_iter)) - elif self._node_type == self._NodeType.CHILD: - node = ChildNodeTensor(next(component_iter), self._is_repeated) - else: - leaf_parent_index = next(component_iter) - leaf_values = next(component_iter) - node = LeafNodeTensor(leaf_parent_index, leaf_values, self._is_repeated) - step_to_child = collections.OrderedDict() - for step, child_spec in self._children_specs: - step_to_child[step] = ( - child_spec._from_component_iter( # pylint: disable=protected-access - component_iter)) - return Prensor(node, step_to_child) - - def _from_components(self, components: List[tf.Tensor]) -> "Prensor": - """Creates a Prensor from the components encoded by _to_components().""" - return self._from_component_iter(iter(components)) - - def _append_to_component_specs(self, component_specs: List[tf.TensorSpec]): - if self._node_type == self._NodeType.ROOT: - component_specs.append(tf.TensorSpec([], tf.int64)) - elif self._node_type == self._NodeType.CHILD: - component_specs.append(tf.TensorSpec([None], tf.int64)) - else: - component_specs.append(tf.TensorSpec([None], tf.int64)) - component_specs.append( - tf.TensorSpec([None], self._value_dtype)) - for _, child_spec in self._children_specs: - child_spec._append_to_component_specs( # pylint: disable=protected-access - component_specs) - - @property - def _component_specs(self) -> List[tf.TensorSpec]: - """Returns TensorSpecs for each tensors returned by _to_components(). - - Returns: - a Tuple of Lists of the same structure as _to_components() Returns. - """ - result = [] - self._append_to_component_specs(result) - return result - - def _serialize(self) -> Tuple[bool, int, tf.DType, Tuple]: # pylint: disable=g-bare-generic - return (self._is_repeated, int(self._node_type), self._value_dtype, - tuple((step, - child_spec._serialize()) # pylint: disable=protected-access - for step, child_spec in self._children_specs)) - - @classmethod - def _deserialize(cls, serialization: Tuple[bool, int, tf.DType, Tuple]): # pylint: disable=g-bare-generic - children_serializations = serialization[3] - children_specs = [(step, cls._deserialize(child_serialization)) - for step, child_serialization in children_serializations] - return cls( - serialization[0], - cls._NodeType(serialization[1]), - serialization[2], - children_specs) - - def _to_legacy_output_types(self): - return tuple(spec.dtype for spec in self._component_specs) - - def _to_legacy_output_shapes(self): - return tuple(spec.shape for spec in self._component_specs) - - def _to_legacy_output_classes(self): - return tuple(tf.Tensor for spec in self._component_specs) + """TypeSpec for Prensor.""" + + class _NodeType(enum.IntEnum): + ROOT = 1 + CHILD = 2 + LEAF = 3 + + __slots__ = ["_is_repeated", "_node_type", "_value_dtype", "_children_specs"] + + def __init__( + self, + is_repeated: Optional[bool], + node_type: _NodeType, + value_dtype: Optional[tf.DType], + children_specs: List[Tuple[path.Step, "_PrensorTypeSpec"]], + ): + self._is_repeated = is_repeated + self._node_type = node_type + self._value_dtype = value_dtype + self._children_specs = children_specs + + @property + def value_type(self): + return Prensor + + def _append_to_components(self, value: "Prensor", components: List[tf.Tensor]): + node = value.node + if self._node_type == self._NodeType.ROOT: + 
assert isinstance(node, RootNodeTensor)
+            components.append(node.size)
+        elif self._node_type == self._NodeType.CHILD:
+            assert isinstance(node, ChildNodeTensor)
+            components.append(node.parent_index)
+        else:
+            assert isinstance(node, LeafNodeTensor)
+            components.append(node.parent_index)
+            components.append(node.values)
+
+        for (_, child_spec), child in zip(
+            self._children_specs, value.get_children().values()
+        ):
+            child_spec._append_to_components(  # pylint: disable=protected-access
+                child, components
+            )
+
+    def _to_components(self, value: "Prensor") -> List[tf.Tensor]:
+        """Encodes a Prensor into a flat list of Tensors.
+
+        Args:
+            value: A Prensor.
+
+        Returns:
+            A list of Tensors which are what each NodeTensor in a Prensor consists of.
+            They are grouped by their owning NodeTensors in pre-order (the children
+            of one node are ordered the same as the _children OrderedDict of that
+            node).
+        """
+        result = []
+        self._append_to_components(value, result)
+        return result
+
+    def _from_component_iter(self, component_iter: Iterator[tf.Tensor]) -> "Prensor":
+        if self._node_type == self._NodeType.ROOT:
+            node = RootNodeTensor(next(component_iter))
+        elif self._node_type == self._NodeType.CHILD:
+            node = ChildNodeTensor(next(component_iter), self._is_repeated)
+        else:
+            leaf_parent_index = next(component_iter)
+            leaf_values = next(component_iter)
+            node = LeafNodeTensor(leaf_parent_index, leaf_values, self._is_repeated)
+        step_to_child = collections.OrderedDict()
+        for step, child_spec in self._children_specs:
+            step_to_child[step] = child_spec._from_component_iter(  # pylint: disable=protected-access
+                component_iter
+            )
+        return Prensor(node, step_to_child)
+
+    def _from_components(self, components: List[tf.Tensor]) -> "Prensor":
+        """Creates a Prensor from the components encoded by _to_components()."""
+        return self._from_component_iter(iter(components))
+
+    def _append_to_component_specs(self, component_specs: List[tf.TensorSpec]):
+        if self._node_type == self._NodeType.ROOT:
+            component_specs.append(tf.TensorSpec([], tf.int64))
+        elif self._node_type == self._NodeType.CHILD:
+            component_specs.append(tf.TensorSpec([None], tf.int64))
+        else:
+            component_specs.append(tf.TensorSpec([None], tf.int64))
+            component_specs.append(tf.TensorSpec([None], self._value_dtype))
+        for _, child_spec in self._children_specs:
+            child_spec._append_to_component_specs(  # pylint: disable=protected-access
+                component_specs
+            )
+
+    @property
+    def _component_specs(self) -> List[tf.TensorSpec]:
+        """Returns TensorSpecs for each tensor returned by _to_components().
+
+        Returns:
+            A list of TensorSpecs, parallel to the list returned by _to_components().
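+            For example, a prensor consisting of only a root node yields
+            [tf.TensorSpec([], tf.int64)].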
+        """
+        result = []
+        self._append_to_component_specs(result)
+        return result
+
+    def _serialize(self) -> Tuple[bool, int, tf.DType, Tuple]:  # pylint: disable=g-bare-generic
+        return (
+            self._is_repeated,
+            int(self._node_type),
+            self._value_dtype,
+            tuple(
+                (step, child_spec._serialize())  # pylint: disable=protected-access
+                for step, child_spec in self._children_specs
+            ),
+        )
+
+    @classmethod
+    def _deserialize(cls, serialization: Tuple[bool, int, tf.DType, Tuple]):  # pylint: disable=g-bare-generic
+        children_serializations = serialization[3]
+        children_specs = [
+            (step, cls._deserialize(child_serialization))
+            for step, child_serialization in children_serializations
+        ]
+        return cls(
+            serialization[0],
+            cls._NodeType(serialization[1]),
+            serialization[2],
+            children_specs,
+        )
+
+    def _to_legacy_output_types(self):
+        return tuple(spec.dtype for spec in self._component_specs)
+
+    def _to_legacy_output_shapes(self):
+        return tuple(spec.shape for spec in self._component_specs)
+
+    def _to_legacy_output_classes(self):
+        return tuple(tf.Tensor for spec in self._component_specs)


 class Prensor(composite_tensor.CompositeTensor):
-  """A expression of NodeTensor objects."""
-
-  __slots__ = ["_node", "_children"]
-
-  def __init__(self, node: NodeTensor,
-               children: "collections.OrderedDict[path.Step, Prensor]"):
-    """Construct a Prensor.
+    """An expression of NodeTensor objects."""
+
+    __slots__ = ["_node", "_children"]
+
+    def __init__(
+        self, node: NodeTensor, children: "collections.OrderedDict[path.Step, Prensor]"
+    ):
+        """Construct a Prensor.
+
+        Do not call directly, instead call either:
+        create_prensor_from_descendant_nodes or
+        create_prensor_from_root_and_children
+
+        Args:
+            node: the NodeTensor of the root.
+            children: a map from edge to subexpression.
+        """
+        self._node = node
+        self._children = children
+
+    @property
+    def node(self) -> NodeTensor:
+        """The node of the root of the subtree."""
+        return self._node
+
+    def get_child(self, field_name: path.Step) -> Optional["Prensor"]:
+        """Gets the child at field_name."""
+        return self._children.get(field_name)
+
+    @property
+    def is_leaf(self) -> bool:
+        """True iff the node value is a LeafNodeTensor."""
+        return isinstance(self._node, LeafNodeTensor)
+
+    def get_child_or_error(self, field_name: path.Step) -> "Prensor":
+        """Gets the child at field_name."""
+        result = self._children.get(field_name)
+        if result is not None:
+            return result
+        raise ValueError(f"Field not found: {str(field_name)} in {str(self)}")
+
+    def get_descendant(self, p: path.Path) -> Optional["Prensor"]:
+        """Finds the descendant at the path."""
+        result = self
+        for field_name in p.field_list:
+            result = result.get_child(field_name)
+            if result is None:
+                return None
+        return result
+
+    def get_descendant_or_error(self, p: path.Path) -> "Prensor":
+        """Finds the descendant at the path."""
+        result = self.get_descendant(p)
+        if result is None:
+            raise ValueError(f"Missing path: {str(p)} in {str(self)}")
+        return result
+
+    def get_children(self) -> "collections.OrderedDict[path.Step, Prensor]":
+        """A map from field name to subexpression."""
+        return self._children
+
+    def get_descendants(self) -> Mapping[path.Path, "Prensor"]:
+        """A map from paths to all subexpressions."""
+        result = {path.Path([]): self}
+        for k, v in self._children.items():
+            subexpression_descendants = v.get_descendants()
+            for k2, v2 in subexpression_descendants.items():
+                # Since k is already a path.Step, we know it is valid and needn't
+                # validate it again.
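+                # For example, a prensor with a single leaf at "foo.bar" yields
+                # the keys Path([]), Path(["foo"]), and Path(["foo", "bar"]).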
+ result[path.Path([k], validate_step_format=False).concat(k2)] = v2 + return result + + def field_names(self) -> FrozenSet[path.Step]: + """Returns the field names of the children.""" + return frozenset(self._children.keys()) + + def get_ragged_tensors( + self, + options: calculate_options.Options = calculate_options.get_default_options(), + ) -> Mapping[path.Path, tf.RaggedTensor]: + """Gets ragged tensors for all the leaves of the prensor expression. + + Args: + options: Options for calculating ragged tensors. + + Returns: + A map from paths to ragged tensors. + """ + return _get_ragged_tensors(self, options=options) + + def get_ragged_tensor( + self, + p: path.Path, + options: calculate_options.Options = calculate_options.get_default_options(), + ) -> tf.RaggedTensor: + """Get a ragged tensor for a path. + + All steps are represented in the ragged tensor. + + Args: + p: the path to a leaf node in `t`. + options: Options for calculating ragged tensors. + + Returns: + A ragged tensor containing values of the leaf node, preserving the + structure along the path. Raises an error if the path is not found. + """ + return _get_ragged_tensor(self, p, options=options) + + def get_sparse_tensor( + self, + p: path.Path, + options: calculate_options.Options = calculate_options.get_default_options(), + ) -> tf.SparseTensor: + """Gets a sparse tensor for path p. + + Note that any optional fields are not registered as dimensions, as they + can't be represented in a sparse tensor. + + Args: + p: The path to a leaf node in `t`. + options: Currently unused. + + Returns: + A sparse tensor containing values of the leaf node, preserving the + structure along the path. Raises an error if the path is not found. + """ + return _get_sparse_tensor(self, p, options=options) + + def get_sparse_tensors( + self, + options: calculate_options.Options = calculate_options.get_default_options(), + ) -> Mapping[path.Path, tf.SparseTensor]: + """Gets sparse tensors for all the leaves of the prensor expression. + + Args: + options: Currently unused. + + Returns: + A map from paths to sparse tensors. + """ + return _get_sparse_tensors(self, options=options) + + def _string_helper(self, field_name: path.Step) -> Sequence[str]: + """Helper for __str__ that outputs a list of lines. + + Args: + field_name: the field name for this node in its parent. + + Returns: + lines to run __str__, that are bytes in Python 2 and unicode in Python 3. 
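+            For example, a root with one repeated child "foo" produces
+            ["RootNodeTensor root", "  repeated ChildNodeTensor foo"].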
+ """ + result = [f"{str(self.node)} {str(field_name)}"] + for k, v in self._children.items(): + recursive = v._string_helper(k) # pylint: disable=protected-access + result.extend([f" {x}" for x in recursive]) + return result + + def __str__(self) -> str: + """Returns a string representing the schema of the Prensor.""" + return "\n".join(self._string_helper("root")) + + @property + def _type_spec(self) -> _PrensorTypeSpec: + is_repeated = None + value_dtype = None + # pylint: disable=protected-access + if isinstance(self.node, RootNodeTensor): + node_type = _PrensorTypeSpec._NodeType.ROOT + elif isinstance(self.node, ChildNodeTensor): + is_repeated = self.node.is_repeated + node_type = _PrensorTypeSpec._NodeType.CHILD + else: + is_repeated = self.node.is_repeated + node_type = _PrensorTypeSpec._NodeType.LEAF + value_dtype = self.node.values.dtype + return _PrensorTypeSpec( + is_repeated, + node_type, + value_dtype, + [(step, child._type_spec) for step, child in self.get_children().items()], + ) + # pylint: enable=protected-access - Do not call directly, instead call either: - create_prensor_from_descendant_nodes or - create_prensor_from_root_and_children - - Args: - node: the NodeTensor of the root. - children: a map from edge to subexpression. - """ - self._node = node - self._children = children - - @property - def node(self) -> NodeTensor: - """The node of the root of the subtree.""" - return self._node - - def get_child(self, field_name: path.Step) -> Optional["Prensor"]: - """Gets the child at field_name.""" - return self._children.get(field_name) - - @property - def is_leaf(self) -> bool: - """True iff the node value is a LeafNodeTensor.""" - return isinstance(self._node, LeafNodeTensor) - - def get_child_or_error(self, field_name: path.Step) -> "Prensor": - """Gets the child at field_name.""" - result = self._children.get(field_name) - if result is not None: - return result - raise ValueError(f"Field not found: {str(field_name)} in {str(self)}") - - def get_descendant(self, p: path.Path) -> Optional["Prensor"]: - """Finds the descendant at the path.""" - result = self - for field_name in p.field_list: - result = result.get_child(field_name) - if result is None: - return None - return result - - def get_descendant_or_error(self, p: path.Path) -> "Prensor": - """Finds the descendant at the path.""" - result = self.get_descendant(p) - if result is None: - raise ValueError(f"Missing path: {str(p)} in {str(self)}") - return result - - def get_children(self) -> "collections.OrderedDict[path.Step, Prensor]": - """A map from field name to subexpression.""" - return self._children - - def get_descendants(self) -> Mapping[path.Path, "Prensor"]: - """A map from paths to all subexpressions.""" - result = {path.Path([]): self} - for k, v in self._children.items(): - subexpression_descendants = v.get_descendants() - for k2, v2 in subexpression_descendants.items(): - # Since k is already a path.Step, we know it is valid and needn't - # validate it again. - result[path.Path([k], validate_step_format=False).concat(k2)] = v2 - return result - - def field_names(self) -> FrozenSet[path.Step]: - """Returns the field names of the children.""" - return frozenset(self._children.keys()) - - def get_ragged_tensors( - self, - options: calculate_options.Options = calculate_options - .get_default_options() - ) -> Mapping[path.Path, tf.RaggedTensor]: - """Gets ragged tensors for all the leaves of the prensor expression. - - Args: - options: Options for calculating ragged tensors. 
- Returns: - A map from paths to ragged tensors. - """ - return _get_ragged_tensors(self, options=options) - - def get_ragged_tensor( - self, - p: path.Path, - options: calculate_options.Options = calculate_options - .get_default_options() - ) -> tf.RaggedTensor: - """Get a ragged tensor for a path. - - All steps are represented in the ragged tensor. - - Args: - p: the path to a leaf node in `t`. - options: Options for calculating ragged tensors. - - Returns: - A ragged tensor containing values of the leaf node, preserving the - structure along the path. Raises an error if the path is not found. - """ - return _get_ragged_tensor(self, p, options=options) - - def get_sparse_tensor( - self, - p: path.Path, - options: calculate_options.Options = calculate_options - .get_default_options() - ) -> tf.SparseTensor: - """Gets a sparse tensor for path p. - - Note that any optional fields are not registered as dimensions, as they - can't be represented in a sparse tensor. - - Args: - p: The path to a leaf node in `t`. - options: Currently unused. - - Returns: - A sparse tensor containing values of the leaf node, preserving the - structure along the path. Raises an error if the path is not found. - """ - return _get_sparse_tensor(self, p, options=options) +def create_prensor_from_descendant_nodes( + nodes: Mapping[path.Path, NodeTensor], +) -> "Prensor": + """Create a prensor from a map of paths to NodeTensor. - def get_sparse_tensors( - self, - options: calculate_options.Options = calculate_options - .get_default_options() - ) -> Mapping[path.Path, tf.SparseTensor]: - """Gets sparse tensors for all the leaves of the prensor expression. + If a path is a key in the map, all prefixes of that path must be present. Args: - options: Currently unused. + nodes: A map from paths to NodeTensors. Returns: - A map from paths to sparse tensors. - """ - return _get_sparse_tensors(self, options=options) - - def _string_helper(self, field_name: path.Step) -> Sequence[str]: - """Helper for __str__ that outputs a list of lines. - - Args: - field_name: the field name for this node in its parent. + A Prensor. - Returns: - lines to run __str__, that are bytes in Python 2 and unicode in Python 3. + Raises: + ValueError: if there is a prefix of a path missing. """ - result = [f"{str(self.node)} {str(field_name)}"] - for k, v in self._children.items(): - recursive = v._string_helper(k) # pylint: disable=protected-access - result.extend([f" {x}" for x in recursive]) - return result - - def __str__(self) -> str: - """Returns a string representing the schema of the Prensor.""" - return "\n".join(self._string_helper("root")) - - @property - def _type_spec(self) -> _PrensorTypeSpec: - is_repeated = None - value_dtype = None - # pylint: disable=protected-access - if isinstance(self.node, RootNodeTensor): - node_type = _PrensorTypeSpec._NodeType.ROOT - elif isinstance(self.node, ChildNodeTensor): - is_repeated = self.node.is_repeated - node_type = _PrensorTypeSpec._NodeType.CHILD - else: - is_repeated = self.node.is_repeated - node_type = _PrensorTypeSpec._NodeType.LEAF - value_dtype = self.node.values.dtype - return _PrensorTypeSpec( - is_repeated, - node_type, - value_dtype, - [(step, child._type_spec) - for step, child in self.get_children().items()]) - # pylint: enable=protected-access - - -def create_prensor_from_descendant_nodes( - nodes: Mapping[path.Path, NodeTensor]) -> "Prensor": - """Create a prensor from a map of paths to NodeTensor. - - If a path is a key in the map, all prefixes of that path must be present. 
- - Args: - nodes: A map from paths to NodeTensors. + subexpressions = collections.OrderedDict() + root_node = None + for k, v in sorted(nodes.items()): + if not k: + root_node = v + else: + first_step = k.field_list[0] + suffix = k.suffix(1) + if first_step not in subexpressions: + subexpressions[first_step] = {} + subexpressions[first_step][suffix] = v + if root_node is None: + raise ValueError(f"No root found: {str(nodes)}") + return create_prensor_from_root_and_children( + root_node, + collections.OrderedDict( + (k, create_prensor_from_descendant_nodes(v)) + for k, v in subexpressions.items() + ), + ) - Returns: - A Prensor. - Raises: - ValueError: if there is a prefix of a path missing. - """ - subexpressions = collections.OrderedDict() - root_node = None - for k, v in sorted(nodes.items()): - if not k: - root_node = v +def create_prensor_from_root_and_children( + root: NodeTensor, children: Mapping[path.Step, Prensor] +) -> Prensor: + if isinstance(children, collections.OrderedDict): + ordered_children = children else: - first_step = k.field_list[0] - suffix = k.suffix(1) - if first_step not in subexpressions: - subexpressions[first_step] = {} - subexpressions[first_step][suffix] = v - if root_node is None: - raise ValueError(f"No root found: {str(nodes)}") - return create_prensor_from_root_and_children( - root_node, - collections.OrderedDict((k, create_prensor_from_descendant_nodes(v)) - for k, v in subexpressions.items())) - + ordered_children = collections.OrderedDict(sorted(children.items())) + return Prensor(root, ordered_children) -def create_prensor_from_root_and_children( - root: NodeTensor, children: Mapping[path.Step, Prensor]) -> Prensor: - if isinstance(children, collections.OrderedDict): - ordered_children = children - else: - ordered_children = collections.OrderedDict(sorted(children.items())) - return Prensor(root, ordered_children) ######## Below code is for converting prensor to ragged/sparse tensors.######### class _LeafNodePath: - """A path ending in a leaf. + """A path ending in a leaf. - In order to avoid type checks and casting in the heart of different methods - using the Prensor object to get a ragged or sparse tensor, we first create a - typed "list" of nodes. A _LeafNodePath always begins with the root and ends - with a leaf. Notice that we can get a suffix by casting a child node to a - root node. - """ + In order to avoid type checks and casting in the heart of different methods + using the Prensor object to get a ragged or sparse tensor, we first create a + typed "list" of nodes. A _LeafNodePath always begins with the root and ends + with a leaf. Notice that we can get a suffix by casting a child node to a + root node. 
+ """ - def __init__(self, head: RootNodeTensor, middle: Sequence[ChildNodeTensor], - tail: LeafNodeTensor): - self._head = head - self._middle = middle - self._tail = tail + def __init__( + self, + head: RootNodeTensor, + middle: Sequence[ChildNodeTensor], + tail: LeafNodeTensor, + ): + self._head = head + self._middle = middle + self._tail = tail - @property - def head(self) -> RootNodeTensor: - return self._head + @property + def head(self) -> RootNodeTensor: + return self._head - @property - def middle(self) -> Sequence[ChildNodeTensor]: - return self._middle + @property + def middle(self) -> Sequence[ChildNodeTensor]: + return self._middle - @property - def tail(self) -> LeafNodeTensor: - return self._tail + @property + def tail(self) -> LeafNodeTensor: + return self._tail class _ChildNodePath: - """A _ChildNodePath is a path that ends with a child node. + """A _ChildNodePath is a path that ends with a child node. - It keeps same triple structure as _LeafNodePath. - We use these in _get_dewey_encoding. - """ + It keeps same triple structure as _LeafNodePath. + We use these in _get_dewey_encoding. + """ - def __init__(self, head: RootNodeTensor, middle: Sequence[ChildNodeTensor], - tail: ChildNodeTensor): - self._head = head - self._middle = middle - self._tail = tail + def __init__( + self, + head: RootNodeTensor, + middle: Sequence[ChildNodeTensor], + tail: ChildNodeTensor, + ): + self._head = head + self._middle = middle + self._tail = tail - @property - def head(self) -> RootNodeTensor: - return self._head + @property + def head(self) -> RootNodeTensor: + return self._head - @property - def middle(self) -> Sequence[ChildNodeTensor]: - return self._middle + @property + def middle(self) -> Sequence[ChildNodeTensor]: + return self._middle - @property - def tail(self) -> ChildNodeTensor: - return self._tail + @property + def tail(self) -> ChildNodeTensor: + return self._tail def _as_root_node_tensor(node_tensor: NodeTensor) -> RootNodeTensor: - if isinstance(node_tensor, RootNodeTensor): - return node_tensor - if isinstance(node_tensor, ChildNodeTensor): - return RootNodeTensor(node_tensor.size) - raise ValueError(f"Must be child or root node tensor (found {type(node_tensor)})") + if isinstance(node_tensor, RootNodeTensor): + return node_tensor + if isinstance(node_tensor, ChildNodeTensor): + return RootNodeTensor(node_tensor.size) + raise ValueError(f"Must be child or root node tensor (found {type(node_tensor)})") def _get_leaf_node_path(p: path.Path, t: Prensor) -> _LeafNodePath: - """Creates a _LeafNodePath to p.""" - leaf_node = t.get_descendant_or_error(p).node - if not isinstance(leaf_node, LeafNodeTensor): - raise ValueError(f"Expected Leaf Node at {str(p)} in {str(t)}") - if not p: - raise ValueError("Leaf should not be at the root") - # If there is a leaf at the root, this will return a ValueError. - root_node = _as_root_node_tensor(t.node) - - # Not the root, not p. - strict_ancestor_paths = [p.prefix(i) for i in range(1, len(p))] - - child_node_pairs = [(t.get_descendant_or_error(ancestor).node, ancestor) - for ancestor in strict_ancestor_paths] - bad_struct_paths = [ - ancestor for node, ancestor in child_node_pairs - if not isinstance(node, ChildNodeTensor) - ] - if bad_struct_paths: - raise ValueError("Expected ChildNodeTensor at {} in {}".format( - " ".join([str(x) for x in bad_struct_paths]), str(t))) - # This should select all elements: the isinstance is for type-checking. 
- child_nodes = [ - node for node, ancestor in child_node_pairs - if isinstance(node, ChildNodeTensor) - ] - assert len(child_nodes) == len(child_node_pairs) - return _LeafNodePath(root_node, child_nodes, leaf_node) + """Creates a _LeafNodePath to p.""" + leaf_node = t.get_descendant_or_error(p).node + if not isinstance(leaf_node, LeafNodeTensor): + raise ValueError(f"Expected Leaf Node at {str(p)} in {str(t)}") + if not p: + raise ValueError("Leaf should not be at the root") + # If there is a leaf at the root, this will return a ValueError. + root_node = _as_root_node_tensor(t.node) + + # Not the root, not p. + strict_ancestor_paths = [p.prefix(i) for i in range(1, len(p))] + + child_node_pairs = [ + (t.get_descendant_or_error(ancestor).node, ancestor) + for ancestor in strict_ancestor_paths + ] + bad_struct_paths = [ + ancestor + for node, ancestor in child_node_pairs + if not isinstance(node, ChildNodeTensor) + ] + if bad_struct_paths: + raise ValueError( + "Expected ChildNodeTensor at {} in {}".format( + " ".join([str(x) for x in bad_struct_paths]), str(t) + ) + ) + # This should select all elements: the isinstance is for type-checking. + child_nodes = [ + node for node, ancestor in child_node_pairs if isinstance(node, ChildNodeTensor) + ] + assert len(child_nodes) == len(child_node_pairs) + return _LeafNodePath(root_node, child_nodes, leaf_node) def _get_leaf_node_path_suffix(p: _LeafNodePath) -> _LeafNodePath: - """Get the suffix of a LeafNodePath.""" - return _LeafNodePath(_as_root_node_tensor(p.middle[0]), p.middle[1:], p.tail) + """Get the suffix of a LeafNodePath.""" + return _LeafNodePath(_as_root_node_tensor(p.middle[0]), p.middle[1:], p.tail) -def _get_node_path_parent( - p: Union[_LeafNodePath, _ChildNodePath]) -> _ChildNodePath: - return _ChildNodePath(p.head, p.middle[:-1], p.middle[-1]) +def _get_node_path_parent(p: Union[_LeafNodePath, _ChildNodePath]) -> _ChildNodePath: + return _ChildNodePath(p.head, p.middle[:-1], p.middle[-1]) def _get_leaf_node_paths(t: Prensor) -> Mapping[path.Path, _LeafNodePath]: - """Gets a map of paths to leaf nodes in the expression.""" - return { - k: _get_leaf_node_path(k, t) - for k, v in t.get_descendants().items() - if isinstance(v.node, LeafNodeTensor) - } + """Gets a map of paths to leaf nodes in the expression.""" + return { + k: _get_leaf_node_path(k, t) + for k, v in t.get_descendants().items() + if isinstance(v.node, LeafNodeTensor) + } #################### Code for _get_sparse_tensors(...) ######################### def _get_dewey_encoding( - p: Union[_LeafNodePath, _ChildNodePath]) -> Tuple[tf.Tensor, tf.Tensor]: - """Gets a pair of the indices and shape of these protos. - - See http://db.ucsd.edu/static/cse232B-s05/papers/tatarinov02.pdf - - Args: - p: the path to encode. - - Returns: - A pair of an indices matrix and a dense_shape - """ - parent = p.middle[-1] if p.middle else p.head - parent_size = tf.reshape(parent.size, [1]) - positional_index = p.tail.get_positional_index() - # tf.reduce_max([]) == -kmaxint64 but we need it to be 0. - current_size = tf.maximum( - tf.reshape(tf.reduce_max(positional_index) + 1, [1]), [0]) - if not p.middle: + p: Union[_LeafNodePath, _ChildNodePath], +) -> Tuple[tf.Tensor, tf.Tensor]: + """Gets a pair of the indices and shape of these protos. + + See http://db.ucsd.edu/static/cse232B-s05/papers/tatarinov02.pdf + + Args: + p: the path to encode. 
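+
+    For example, a repeated leaf with parent_index [0, 0, 1] under a root of
+    size 2 encodes to indices [[0, 0], [0, 1], [1, 0]] with dense_shape [2, 2].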
+ + Returns: + A pair of an indices matrix and a dense_shape + """ + parent = p.middle[-1] if p.middle else p.head + parent_size = tf.reshape(parent.size, [1]) + positional_index = p.tail.get_positional_index() + # tf.reduce_max([]) == -kmaxint64 but we need it to be 0. + current_size = tf.maximum(tf.reshape(tf.reduce_max(positional_index) + 1, [1]), [0]) + if not p.middle: + if p.tail.is_repeated: + return ( + tf.stack([p.tail.parent_index, positional_index], axis=1), + tf.concat([parent_size, current_size], 0), + ) + return tf.expand_dims(p.tail.parent_index, -1), parent_size + parent_dewey_encoding, parent_size = _get_dewey_encoding(_get_node_path_parent(p)) if p.tail.is_repeated: - return (tf.stack([p.tail.parent_index, positional_index], - axis=1), tf.concat([parent_size, current_size], 0)) - return tf.expand_dims(p.tail.parent_index, -1), parent_size - parent_dewey_encoding, parent_size = _get_dewey_encoding( - _get_node_path_parent(p)) - if p.tail.is_repeated: - positional_index_as_matrix = tf.expand_dims( - p.tail.get_positional_index(), -1) - indices = tf.concat([ - tf.gather(parent_dewey_encoding, p.tail.parent_index), - positional_index_as_matrix - ], 1) - size = tf.concat([parent_size, current_size], 0) - return (indices, size) - return tf.gather(parent_dewey_encoding, p.tail.parent_index), parent_size + positional_index_as_matrix = tf.expand_dims(p.tail.get_positional_index(), -1) + indices = tf.concat( + [ + tf.gather(parent_dewey_encoding, p.tail.parent_index), + positional_index_as_matrix, + ], + 1, + ) + size = tf.concat([parent_size, current_size], 0) + return (indices, size) + return tf.gather(parent_dewey_encoding, p.tail.parent_index), parent_size def _get_sparse_tensor_from_leaf_node_path(p: _LeafNodePath) -> tf.SparseTensor: - indices, dense_shape = _get_dewey_encoding(p) - return tf.SparseTensor( - indices=indices, values=p.tail.values, dense_shape=dense_shape) + indices, dense_shape = _get_dewey_encoding(p) + return tf.SparseTensor( + indices=indices, values=p.tail.values, dense_shape=dense_shape + ) def _get_sparse_tensor( t: Prensor, p: path.Path, - options: calculate_options.Options = calculate_options.get_default_options( - ) + options: calculate_options.Options = calculate_options.get_default_options(), ) -> tf.SparseTensor: - """Gets a sparse tensor for path p. + """Gets a sparse tensor for path p. - Note that any optional fields are not registered as dimensions, as they can't - be represented in a sparse tensor. + Note that any optional fields are not registered as dimensions, as they can't + be represented in a sparse tensor. - Args: - t: The Prensor to extract tensors from. - p: The path to a leaf node in `t`. - options: Currently unused. + Args: + t: The Prensor to extract tensors from. + p: The path to a leaf node in `t`. + options: Currently unused. - Returns: - A sparse tensor containing values of the leaf node, preserving the - structure along the path. Raises an error if the path is not found. - """ - del options - leaf_node_path = _get_leaf_node_path(p, t) - return _get_sparse_tensor_from_leaf_node_path(leaf_node_path) + Returns: + A sparse tensor containing values of the leaf node, preserving the + structure along the path. Raises an error if the path is not found. 
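+
+    Example (a sketch, not part of the library; assumes `t` is a Prensor with a
+    leaf at the hypothetical path "doc.bar"):
+      st = _get_sparse_tensor(t, path.create_path("doc.bar"))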
+ """ + del options + leaf_node_path = _get_leaf_node_path(p, t) + return _get_sparse_tensor_from_leaf_node_path(leaf_node_path) def _get_sparse_tensors( t: Prensor, - options: calculate_options.Options = calculate_options.get_default_options( - ) + options: calculate_options.Options = calculate_options.get_default_options(), ) -> Mapping[path.Path, tf.SparseTensor]: - """Gets sparse tensors for all the leaves of the prensor expression. + """Gets sparse tensors for all the leaves of the prensor expression. - Args: - t: The Prensor to extract tensors from. - options: Currently unused. + Args: + t: The Prensor to extract tensors from. + options: Currently unused. - Returns: - A map from paths to sparse tensors. - """ - del options - return { - p: _get_sparse_tensor_from_leaf_node_path(v) - for p, v in _get_leaf_node_paths(t).items() - } + Returns: + A map from paths to sparse tensors. + """ + del options + return { + p: _get_sparse_tensor_from_leaf_node_path(v) + for p, v in _get_leaf_node_paths(t).items() + } #################### Code for _get_ragged_tensors(...) ######################### -def from_value_rowids_bridge(values, - value_rowids=None, - nrows=None, - validate=True): - """Validate option is only available internally for tf 0.13.1.""" - return tf.RaggedTensor.from_value_rowids( - values, value_rowids=value_rowids, nrows=nrows, validate=validate) +def from_value_rowids_bridge(values, value_rowids=None, nrows=None, validate=True): + """Validate option is only available internally for tf 0.13.1.""" + return tf.RaggedTensor.from_value_rowids( + values, value_rowids=value_rowids, nrows=nrows, validate=validate + ) def _get_ragged_tensor_from_leaf_node_path( nodes: _LeafNodePath, - options: calculate_options.Options = calculate_options.get_default_options( - ) + options: calculate_options.Options = calculate_options.get_default_options(), ) -> tf.RaggedTensor: - """Gets a ragged tensor from a leaf node path.""" - if not nodes.middle: + """Gets a ragged tensor from a leaf node path.""" + if not nodes.middle: + return from_value_rowids_bridge( + nodes.tail.values, + value_rowids=nodes.tail.parent_index, + nrows=nodes.head.size, + validate=options.ragged_checks, + ) + deeper_ragged = _get_ragged_tensor_from_leaf_node_path( + _get_leaf_node_path_suffix(nodes), options + ) + first_child_node = nodes.middle[0] return from_value_rowids_bridge( - nodes.tail.values, - value_rowids=nodes.tail.parent_index, + deeper_ragged, + value_rowids=first_child_node.parent_index, nrows=nodes.head.size, - validate=options.ragged_checks) - deeper_ragged = _get_ragged_tensor_from_leaf_node_path( - _get_leaf_node_path_suffix(nodes), options) - first_child_node = nodes.middle[0] - return from_value_rowids_bridge( - deeper_ragged, - value_rowids=first_child_node.parent_index, - nrows=nodes.head.size, - validate=options.ragged_checks) + validate=options.ragged_checks, + ) def _get_ragged_tensor( t: Prensor, p: path.Path, - options: calculate_options.Options = calculate_options.get_default_options( - ) + options: calculate_options.Options = calculate_options.get_default_options(), ) -> tf.RaggedTensor: - """Get a ragged tensor for a path. + """Get a ragged tensor for a path. - All steps are represented in the ragged tensor. + All steps are represented in the ragged tensor. - Args: - t: The Prensor to extract tensors from. - p: the path to a leaf node in `t`. - options: used to pass options for calculating ragged tensors. + Args: + t: The Prensor to extract tensors from. + p: the path to a leaf node in `t`. 
+ options: used to pass options for calculating ragged tensors. - Returns: - A ragged tensor containing values of the leaf node, preserving the - structure along the path. Raises an error if the path is not found. - """ - leaf_node_path = _get_leaf_node_path(p, t) - return _get_ragged_tensor_from_leaf_node_path(leaf_node_path, options) + Returns: + A ragged tensor containing values of the leaf node, preserving the + structure along the path. Raises an error if the path is not found. + """ + leaf_node_path = _get_leaf_node_path(p, t) + return _get_ragged_tensor_from_leaf_node_path(leaf_node_path, options) def _get_ragged_tensors( t: Prensor, - options: calculate_options.Options = calculate_options.get_default_options( - ) + options: calculate_options.Options = calculate_options.get_default_options(), ) -> Mapping[path.Path, tf.RaggedTensor]: - """Gets ragged tensors for all the leaves of the prensor expression. - - Args: - t: The Prensor to extract tensors from. - options: used to pass options for calculating ragged tensors. - - Returns: - A map from paths to ragged tensors. - """ - return { - p: _get_ragged_tensor_from_leaf_node_path(v, options) - for p, v in _get_leaf_node_paths(t).items() - } + """Gets ragged tensors for all the leaves of the prensor expression. + + Args: + t: The Prensor to extract tensors from. + options: used to pass options for calculating ragged tensors. + + Returns: + A map from paths to ragged tensors. + """ + return { + p: _get_ragged_tensor_from_leaf_node_path(v, options) + for p, v in _get_leaf_node_paths(t).items() + } diff --git a/struct2tensor/prensor_test.py b/struct2tensor/prensor_test.py index 2dd27c7..d412451 100644 --- a/struct2tensor/prensor_test.py +++ b/struct2tensor/prensor_test.py @@ -16,7 +16,7 @@ import tensorflow as tf from tensorflow.python.framework import ( - test_util, # pylint: disable=g-direct-tensorflow-import + test_util, # pylint: disable=g-direct-tensorflow-import ) from struct2tensor import calculate_options, path, prensor @@ -24,242 +24,272 @@ _OPTIONS_TO_TEST = [ calculate_options.get_default_options(), - calculate_options.get_options_with_minimal_checks() + calculate_options.get_options_with_minimal_checks(), ] @test_util.run_all_in_graph_and_eager_modes class PrensorTest(tf.test.TestCase): + def _assert_prensor_equals(self, lhs, rhs): + if isinstance(lhs.node, prensor.RootNodeTensor): + self.assertIsInstance(rhs.node, prensor.RootNodeTensor) + self.assertIs(lhs.node.size, rhs.node.size) + elif isinstance(lhs.node, prensor.ChildNodeTensor): + self.assertIsInstance(rhs.node, prensor.ChildNodeTensor) + self.assertIs(lhs.node.parent_index, rhs.node.parent_index) + self.assertEqual(lhs.node.is_repeated, rhs.node.is_repeated) + else: + self.assertIsInstance(rhs.node, prensor.LeafNodeTensor) + self.assertIs(lhs.node.parent_index, rhs.node.parent_index) + self.assertIs(lhs.node.values, rhs.node.values) + self.assertEqual(lhs.node.is_repeated, rhs.node.is_repeated) + + self.assertEqual(len(lhs.get_children()), len(rhs.get_children())) + for (l_child_step, l_child), (r_child_step, r_child) in zip( + lhs.get_children().items(), rhs.get_children().items() + ): + self.assertEqual(l_child_step, r_child_step) + self._assert_prensor_equals(l_child, r_child) + + def test_prensor_children_ordered(self): + def _recursively_check_sorted(p): + self.assertEqual( + list(p.get_children().keys()), sorted(p.get_children().keys()) + ) + for c in p.get_children().values(): + _recursively_check_sorted(c) + + for pren in [ + 
+            prensor_test_util.create_nested_prensor(),
+            prensor_test_util.create_big_prensor(),
+            prensor_test_util.create_deep_prensor(),
+        ]:
+            _recursively_check_sorted(pren)
+
+        p = prensor.create_prensor_from_descendant_nodes(
+            {
+                path.Path([]): prensor_test_util.create_root_node(1),
+                path.Path(["d"]): prensor_test_util.create_optional_leaf_node(
+                    [0], [True]
+                ),
+                path.Path(["c"]): prensor_test_util.create_optional_leaf_node(
+                    [0], [True]
+                ),
+                path.Path(["b"]): prensor_test_util.create_optional_leaf_node(
+                    [0], [True]
+                ),
+                path.Path(["a"]): prensor_test_util.create_optional_leaf_node(
+                    [0], [True]
+                ),
+            }
+        )
+        self.assertEqual(["a", "b", "c", "d"], list(p.get_children().keys()))
+
+    def test_prensor_is_composite_tensor(self):
+        for pren in [
+            prensor_test_util.create_nested_prensor(),
+            prensor_test_util.create_big_prensor(),
+            prensor_test_util.create_deep_prensor(),
+        ]:
+            flattened_tensors = tf.nest.flatten(pren, expand_composites=True)
+            self.assertIsInstance(flattened_tensors, list)
+            for t in flattened_tensors:
+                self.assertIsInstance(t, tf.Tensor)
+            packed_pren = tf.nest.pack_sequence_as(
+                pren, flattened_tensors, expand_composites=True
+            )
+            self._assert_prensor_equals(pren, packed_pren)
+
+    def test_prensor_to_ragged_tensors(self):
+        for options in _OPTIONS_TO_TEST:
+            pren = prensor_test_util.create_nested_prensor()
+            ragged_tensor_map = pren.get_ragged_tensors(options=options)
+            string_tensor_map = {str(k): v for k, v in ragged_tensor_map.items()}
+            string_np_map = self.evaluate(string_tensor_map)
+            self.assertAllEqual(
+                string_np_map["doc.bar"].to_list(),
+                [[[b"a"]], [[b"b", b"c"], [b"d"]], []],
+            )
+
+            self.assertAllEqual(
+                string_np_map["doc.keep_me"].to_list(), [[[False]], [[True], []], []]
+            )
+            self.assertAllEqual(
+                string_np_map["user.friends"].to_list(),
+                [[[b"a"]], [[b"b", b"c"], [b"d"]], [[b"e"]]],
+            )
+
+    def test_prensor_to_ragged_tensor(self):
+        for options in _OPTIONS_TO_TEST:
+            pren = prensor_test_util.create_nested_prensor()
+            ragged_tensor = pren.get_ragged_tensor(path.create_path("doc.bar"), options)
+            self.assertAllEqual(ragged_tensor, [[[b"a"]], [[b"b", b"c"], [b"d"]], []])
+
+    def test_prensor_to_sparse_tensors(self):
+        for options in _OPTIONS_TO_TEST:
+            pren = prensor_test_util.create_nested_prensor()
+            sparse_tensor_map = pren.get_sparse_tensors(options=options)
+            string_tensor_map = {str(k): v for k, v in sparse_tensor_map.items()}
+
+            self.assertAllEqual(
+                string_tensor_map["doc.bar"].indices,
+                [[0, 0, 0], [1, 0, 0], [1, 0, 1], [1, 1, 0]],
+            )
+            self.assertAllEqual(
+                string_tensor_map["doc.bar"].values, [b"a", b"b", b"c", b"d"]
+            )
+            self.assertAllEqual(
+                string_tensor_map["doc.keep_me"].indices, [[0, 0], [1, 0]]
+            )
+            self.assertAllEqual(string_tensor_map["doc.keep_me"].values, [False, True])
+
+            self.assertAllEqual(
+                string_tensor_map["user.friends"].indices,
+                [[0, 0, 0], [1, 0, 0], [1, 0, 1], [1, 1, 0], [2, 0, 0]],
+            )
+            self.assertAllEqual(
+                string_tensor_map["user.friends"].values, [b"a", b"b", b"c", b"d", b"e"]
+            )
+
+    def test_prensor_to_sparse_tensor(self):
+        for options in _OPTIONS_TO_TEST:
+            pren = prensor_test_util.create_simple_prensor()
+            sparse_tensor = pren.get_sparse_tensor(
+                path.create_path("foo"), options=options
+            )
+            self.assertAllEqual(sparse_tensor.indices, [[0], [1], [2]])
+            self.assertAllEqual(sparse_tensor.dense_shape, [3])
+            self.assertAllEqual(sparse_tensor.values, [9, 8, 7])
+
+    def test_get_leaf_node_paths(self):
+        """Tests _get_leaf_node_paths on a nested prensor."""
+        expression = prensor_test_util.create_nested_prensor()
+        leaf_node_paths = prensor._get_leaf_node_paths(expression)
+        self.assertIn(path.Path(["doc", "bar"]), leaf_node_paths)
+        self.assertIn(path.Path(["user", "friends"]), leaf_node_paths)
+        self.assertIn(path.Path(["doc", "keep_me"]), leaf_node_paths)
+
+    def test_is_leaf(self):
+        """Tests is_leaf on a nested prensor."""
+        expression = prensor_test_util.create_nested_prensor()
+        self.assertTrue(
+            expression.get_descendant_or_error(path.Path(["doc", "bar"])).is_leaf
+        )
+        self.assertFalse(expression.get_descendant_or_error(path.Path(["doc"])).is_leaf)
+        self.assertFalse(expression.get_descendant_or_error(path.Path([])).is_leaf)
+        self.assertTrue(
+            expression.get_descendant_or_error(path.Path(["user", "friends"])).is_leaf
+        )
+        self.assertTrue(
+            expression.get_descendant_or_error(path.Path(["doc", "keep_me"])).is_leaf
+        )
+
+    def test_get_sparse_tensors(self):
+        """Tests _get_sparse_tensors on a nested prensor."""
+        for options in _OPTIONS_TO_TEST:
+            expression = prensor_test_util.create_nested_prensor()
+            sparse_tensor_map = prensor._get_sparse_tensors(expression, options)
+            string_tensor_map = {str(k): v for k, v in sparse_tensor_map.items()}
+
+            self.assertAllEqual(
+                string_tensor_map["doc.bar"].indices,
+                [[0, 0, 0], [1, 0, 0], [1, 0, 1], [1, 1, 0]],
+            )
+            self.assertAllEqual(
+                string_tensor_map["doc.bar"].values, [b"a", b"b", b"c", b"d"]
+            )
+            self.assertAllEqual(
+                string_tensor_map["doc.keep_me"].indices, [[0, 0], [1, 0]]
+            )
+            self.assertAllEqual(string_tensor_map["doc.keep_me"].values, [False, True])
+
+            self.assertAllEqual(
+                string_tensor_map["user.friends"].indices,
+                [[0, 0, 0], [1, 0, 0], [1, 0, 1], [1, 1, 0], [2, 0, 0]],
+            )
+            self.assertAllEqual(
+                string_tensor_map["user.friends"].values, [b"a", b"b", b"c", b"d", b"e"]
+            )
+
+    def test_get_sparse_tensors_simple(self):
+        """Tests _get_sparse_tensors on a simple prensor."""
+        for options in _OPTIONS_TO_TEST:
+            expression = prensor_test_util.create_simple_prensor()
+            sparse_tensor_map = prensor._get_sparse_tensors(expression, options)
+            string_tensor_map = {str(k): v for k, v in sparse_tensor_map.items()}
+            self.assertAllEqual(string_tensor_map["foo"].indices, [[0], [1], [2]])
+            self.assertAllEqual(string_tensor_map["foo"].dense_shape, [3])
+
+            self.assertAllEqual(string_tensor_map["foo"].values, [9, 8, 7])
+            self.assertAllEqual(
+                string_tensor_map["foorepeated"].indices,
+                [[0, 0], [1, 0], [1, 1], [2, 0]],
+            )
+            self.assertAllEqual(string_tensor_map["foorepeated"].values, [9, 8, 7, 6])
+            self.assertAllEqual(string_tensor_map["foorepeated"].dense_shape, [3, 2])
+
+    def test_get_sparse_tensor(self):
+        expression = prensor_test_util.create_simple_prensor()
+        sparse_tensor = prensor._get_sparse_tensor(expression, path.create_path("foo"))
+        self.assertAllEqual(sparse_tensor.indices, [[0], [1], [2]])
+        self.assertAllEqual(sparse_tensor.dense_shape, [3])
+        self.assertAllEqual(sparse_tensor.values, [9, 8, 7])
+
+    def test_get_sparse_tensors_simple_dense(self):
+        """Tests _get_sparse_tensors, densified, on a simple prensor."""
+        for options in _OPTIONS_TO_TEST:
+            expression = prensor_test_util.create_simple_prensor()
+            sparse_tensor_map = prensor._get_sparse_tensors(expression, options)
+            string_tensor_map = {
+                str(k): tf.sparse.to_dense(v) for k, v in sparse_tensor_map.items()
+            }
+
+            self.assertAllEqual(string_tensor_map["foo"], [9, 8, 7])
+            self.assertAllEqual(
+                string_tensor_map["foorepeated"], [[9, 0], [8, 7], [6, 0]]
+            )
+
+    def test_broken_ragged_tensors_no_check(self):
+        """Make
sure that it doesn't crash. The result is undefined.""" + expression = prensor_test_util.create_broken_prensor() + ragged_tensor_map = prensor._get_ragged_tensors( + expression, calculate_options.get_options_with_minimal_checks() + ) + string_tensor_map = {str(k): v for k, v in ragged_tensor_map.items()} + self.evaluate(string_tensor_map) + + # Okay, need to break this apart to handle the V1/V2 issues. + def test_get_ragged_tensors(self): + """Tests get_ragged_tensors on a deep expression.""" + for options in _OPTIONS_TO_TEST: + expression = prensor_test_util.create_nested_prensor() + ragged_tensor_map = prensor._get_ragged_tensors(expression, options) + string_tensor_map = {str(k): v for k, v in ragged_tensor_map.items()} + string_np_map = self.evaluate(string_tensor_map) + self.assertAllEqual( + string_np_map["doc.bar"].to_list(), + [[[b"a"]], [[b"b", b"c"], [b"d"]], []], + ) + + self.assertAllEqual( + string_np_map["doc.keep_me"].to_list(), [[[False]], [[True], []], []] + ) + self.assertAllEqual( + string_np_map["user.friends"].to_list(), + [[[b"a"]], [[b"b", b"c"], [b"d"]], [[b"e"]]], + ) + + def test_get_ragged_tensor(self): + """Tests get_ragged_tensor on a deep field.""" + for options in _OPTIONS_TO_TEST: + expression = prensor_test_util.create_nested_prensor() + ragged_tensor = prensor._get_ragged_tensor( + expression, path.create_path("doc.bar"), options + ) + self.assertAllEqual(ragged_tensor, [[[b"a"]], [[b"b", b"c"], [b"d"]], []]) - def _assert_prensor_equals(self, lhs, rhs): - if isinstance(lhs.node, prensor.RootNodeTensor): - self.assertIsInstance(rhs.node, prensor.RootNodeTensor) - self.assertIs(lhs.node.size, rhs.node.size) - elif isinstance(lhs.node, prensor.ChildNodeTensor): - self.assertIsInstance(rhs.node, prensor.ChildNodeTensor) - self.assertIs(lhs.node.parent_index, rhs.node.parent_index) - self.assertEqual(lhs.node.is_repeated, rhs.node.is_repeated) - else: - self.assertIsInstance(rhs.node, prensor.LeafNodeTensor) - self.assertIs(lhs.node.parent_index, rhs.node.parent_index) - self.assertIs(lhs.node.values, rhs.node.values) - self.assertEqual(lhs.node.is_repeated, rhs.node.is_repeated) - - self.assertEqual(len(lhs.get_children()), len(rhs.get_children())) - for (l_child_step, l_child), (r_child_step, r_child) in zip( - lhs.get_children().items(), rhs.get_children().items()): - self.assertEqual(l_child_step, r_child_step) - self._assert_prensor_equals(l_child, r_child) - - def test_prensor_children_ordered(self): - def _recursively_check_sorted(p): - self.assertEqual(list(p.get_children().keys()), - sorted(p.get_children().keys())) - for c in p.get_children().values(): - _recursively_check_sorted(c) - - for pren in [ - prensor_test_util.create_nested_prensor(), - prensor_test_util.create_big_prensor(), - prensor_test_util.create_deep_prensor() - ]: - _recursively_check_sorted(pren) - - p = prensor.create_prensor_from_descendant_nodes({ - path.Path([]): - prensor_test_util.create_root_node(1), - path.Path(["d"]): - prensor_test_util.create_optional_leaf_node([0], [True]), - path.Path(["c"]): - prensor_test_util.create_optional_leaf_node([0], [True]), - path.Path(["b"]): - prensor_test_util.create_optional_leaf_node([0], [True]), - path.Path(["a"]): - prensor_test_util.create_optional_leaf_node([0], [True]), - }) - self.assertEqual(["a", "b", "c", "d"], list(p.get_children().keys())) - - def test_prensor_is_composite_tensor(self): - for pren in [ - prensor_test_util.create_nested_prensor(), - prensor_test_util.create_big_prensor(), - 
prensor_test_util.create_deep_prensor() - ]: - flattened_tensors = tf.nest.flatten(pren, expand_composites=True) - self.assertIsInstance(flattened_tensors, list) - for t in flattened_tensors: - self.assertIsInstance(t, tf.Tensor) - packed_pren = tf.nest.pack_sequence_as( - pren, flattened_tensors, expand_composites=True) - self._assert_prensor_equals(pren, packed_pren) - - def test_prensor_to_ragged_tensors(self): - for options in _OPTIONS_TO_TEST: - pren = prensor_test_util.create_nested_prensor() - ragged_tensor_map = pren.get_ragged_tensors(options=options) - string_tensor_map = {str(k): v for k, v in ragged_tensor_map.items()} - string_np_map = self.evaluate(string_tensor_map) - self.assertAllEqual(string_np_map["doc.bar"].to_list(), - [[[b"a"]], [[b"b", b"c"], [b"d"]], []]) - - self.assertAllEqual(string_np_map["doc.keep_me"].to_list(), - [[[False]], [[True], []], []]) - self.assertAllEqual(string_np_map["user.friends"].to_list(), - [[[b"a"]], [[b"b", b"c"], [b"d"]], [[b"e"]]]) - - def test_prensor_to_ragged_tensor(self): - for options in _OPTIONS_TO_TEST: - pren = prensor_test_util.create_nested_prensor() - ragged_tensor = pren.get_ragged_tensor( - path.create_path("doc.bar"), options) - self.assertAllEqual(ragged_tensor, [[[b"a"]], [[b"b", b"c"], [b"d"]], []]) - - def test_prensor_to_sparse_tensors(self): - for options in _OPTIONS_TO_TEST: - pren = prensor_test_util.create_nested_prensor() - sparse_tensor_map = pren.get_sparse_tensors(options=options) - string_tensor_map = {str(k): v for k, v in sparse_tensor_map.items()} - - self.assertAllEqual(string_tensor_map["doc.bar"].indices, - [[0, 0, 0], [1, 0, 0], [1, 0, 1], [1, 1, 0]]) - self.assertAllEqual(string_tensor_map["doc.bar"].values, - [b"a", b"b", b"c", b"d"]) - self.assertAllEqual(string_tensor_map["doc.keep_me"].indices, - [[0, 0], [1, 0]]) - self.assertAllEqual(string_tensor_map["doc.keep_me"].values, - [False, True]) - - self.assertAllEqual( - string_tensor_map["user.friends"].indices, - [[0, 0, 0], [1, 0, 0], [1, 0, 1], [1, 1, 0], [2, 0, 0]]) - self.assertAllEqual(string_tensor_map["user.friends"].values, - [b"a", b"b", b"c", b"d", b"e"]) - - def test_prensor_to_sparse_tensor(self): - for options in _OPTIONS_TO_TEST: - pren = prensor_test_util.create_simple_prensor() - sparse_tensor = pren.get_sparse_tensor( - path.create_path("foo"), options=options) - self.assertAllEqual(sparse_tensor.indices, [[0], [1], [2]]) - self.assertAllEqual(sparse_tensor.dense_shape, [3]) - self.assertAllEqual(sparse_tensor.values, [9, 8, 7]) - - def test_get_leaf_node_paths(self): - """Tests get_sparse_tensors on a deep expression.""" - expression = prensor_test_util.create_nested_prensor() - leaf_node_paths = prensor._get_leaf_node_paths(expression) - self.assertIn(path.Path(["doc", "bar"]), leaf_node_paths) - self.assertIn(path.Path(["user", "friends"]), leaf_node_paths) - self.assertIn(path.Path(["doc", "keep_me"]), leaf_node_paths) - - def test_is_leaf(self): - """Tests get_sparse_tensors on a deep expression.""" - expression = prensor_test_util.create_nested_prensor() - self.assertTrue( - expression.get_descendant_or_error(path.Path(["doc", "bar"])).is_leaf) - self.assertFalse( - expression.get_descendant_or_error(path.Path(["doc"])).is_leaf) - self.assertFalse(expression.get_descendant_or_error(path.Path([])).is_leaf) - self.assertTrue( - expression.get_descendant_or_error(path.Path(["user", - "friends"])).is_leaf) - self.assertTrue( - expression.get_descendant_or_error(path.Path(["doc", - "keep_me"])).is_leaf) - - def 
test_get_sparse_tensors(self): - """Tests get_sparse_tensors on a deep expression.""" - for options in _OPTIONS_TO_TEST: - expression = prensor_test_util.create_nested_prensor() - sparse_tensor_map = prensor._get_sparse_tensors(expression, options) - string_tensor_map = {str(k): v for k, v in sparse_tensor_map.items()} - - self.assertAllEqual(string_tensor_map["doc.bar"].indices, - [[0, 0, 0], [1, 0, 0], [1, 0, 1], [1, 1, 0]]) - self.assertAllEqual(string_tensor_map["doc.bar"].values, - [b"a", b"b", b"c", b"d"]) - self.assertAllEqual(string_tensor_map["doc.keep_me"].indices, - [[0, 0], [1, 0]]) - self.assertAllEqual(string_tensor_map["doc.keep_me"].values, - [False, True]) - - self.assertAllEqual( - string_tensor_map["user.friends"].indices, - [[0, 0, 0], [1, 0, 0], [1, 0, 1], [1, 1, 0], [2, 0, 0]]) - self.assertAllEqual(string_tensor_map["user.friends"].values, - [b"a", b"b", b"c", b"d", b"e"]) - - def test_get_sparse_tensors_simple(self): - """Tests get_sparse_tensors on a deep expression.""" - for options in _OPTIONS_TO_TEST: - expression = prensor_test_util.create_simple_prensor() - sparse_tensor_map = prensor._get_sparse_tensors(expression, options) - string_tensor_map = {str(k): v for k, v in sparse_tensor_map.items()} - self.assertAllEqual(string_tensor_map["foo"].indices, [[0], [1], [2]]) - self.assertAllEqual(string_tensor_map["foo"].dense_shape, [3]) - - self.assertAllEqual(string_tensor_map["foo"].values, [9, 8, 7]) - self.assertAllEqual(string_tensor_map["foorepeated"].indices, - [[0, 0], [1, 0], [1, 1], [2, 0]]) - self.assertAllEqual(string_tensor_map["foorepeated"].values, [9, 8, 7, 6]) - self.assertAllEqual(string_tensor_map["foorepeated"].dense_shape, [3, 2]) - - def test_get_sparse_tensor(self): - expression = prensor_test_util.create_simple_prensor() - sparse_tensor = prensor._get_sparse_tensor(expression, - path.create_path("foo")) - self.assertAllEqual(sparse_tensor.indices, [[0], [1], [2]]) - self.assertAllEqual(sparse_tensor.dense_shape, [3]) - self.assertAllEqual(sparse_tensor.values, [9, 8, 7]) - - def test_get_sparse_tensors_simple_dense(self): - """Tests get_sparse_tensors on a deep expression.""" - for options in _OPTIONS_TO_TEST: - expression = prensor_test_util.create_simple_prensor() - sparse_tensor_map = prensor._get_sparse_tensors(expression, options) - string_tensor_map = { - str(k): tf.sparse.to_dense(v) for k, v in sparse_tensor_map.items() - } - - self.assertAllEqual(string_tensor_map["foo"], [9, 8, 7]) - self.assertAllEqual(string_tensor_map["foorepeated"], - [[9, 0], [8, 7], [6, 0]]) - - - def test_broken_ragged_tensors_no_check(self): - """Make sure that it doesn't crash. The result is undefined.""" - expression = prensor_test_util.create_broken_prensor() - ragged_tensor_map = prensor._get_ragged_tensors( - expression, calculate_options.get_options_with_minimal_checks()) - string_tensor_map = {str(k): v for k, v in ragged_tensor_map.items()} - self.evaluate(string_tensor_map) - - # Okay, need to break this apart to handle the V1/V2 issues. 
- def test_get_ragged_tensors(self): - """Tests get_ragged_tensors on a deep expression.""" - for options in _OPTIONS_TO_TEST: - expression = prensor_test_util.create_nested_prensor() - ragged_tensor_map = prensor._get_ragged_tensors(expression, options) - string_tensor_map = {str(k): v for k, v in ragged_tensor_map.items()} - string_np_map = self.evaluate(string_tensor_map) - self.assertAllEqual(string_np_map["doc.bar"].to_list(), - [[[b"a"]], [[b"b", b"c"], [b"d"]], []]) - - self.assertAllEqual(string_np_map["doc.keep_me"].to_list(), - [[[False]], [[True], []], []]) - self.assertAllEqual(string_np_map["user.friends"].to_list(), - [[[b"a"]], [[b"b", b"c"], [b"d"]], [[b"e"]]]) - - def test_get_ragged_tensor(self): - """Tests get_ragged_tensor on a deep field.""" - for options in _OPTIONS_TO_TEST: - expression = prensor_test_util.create_nested_prensor() - ragged_tensor = prensor._get_ragged_tensor(expression, - path.create_path("doc.bar"), - options) - self.assertAllEqual(ragged_tensor, [[[b"a"]], [[b"b", b"c"], [b"d"]], []]) # The following are only available post TF 1.14. if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/struct2tensor/prensor_to_structured_tensor.py b/struct2tensor/prensor_to_structured_tensor.py index fd742a8..40cec8b 100644 --- a/struct2tensor/prensor_to_structured_tensor.py +++ b/struct2tensor/prensor_to_structured_tensor.py @@ -31,85 +31,95 @@ import tensorflow as tf from tensorflow.python.ops.ragged.row_partition import ( - RowPartition, # pylint: disable=g-direct-tensorflow-import + RowPartition, # pylint: disable=g-direct-tensorflow-import ) from tensorflow.python.ops.structured import ( - structured_tensor, # pylint: disable=g-direct-tensorflow-import + structured_tensor, # pylint: disable=g-direct-tensorflow-import ) from struct2tensor import path, prensor def prensor_to_structured_tensor( - p: prensor.Prensor) -> structured_tensor.StructuredTensor: - """Creates a structured tensor from a prensor. + p: prensor.Prensor, +) -> structured_tensor.StructuredTensor: + """Creates a structured tensor from a prensor. - All information about optional and repeated fields is dropped. - If the field names in the proto do not meet the specifications for - StructuredTensor, the behavior is undefined. + All information about optional and repeated fields is dropped. + If the field names in the proto do not meet the specifications for + StructuredTensor, the behavior is undefined. - Args: - p: the prensor to convert. + Args: + p: the prensor to convert. - Returns: - An equivalent StructuredTensor. + Returns: + An equivalent StructuredTensor. - Raises: - ValueError: if the root of the prensor is not a RootNodeTensor. - """ - node = p.node - if isinstance(node, prensor.RootNodeTensor): - return _root_node_to_structured_tensor( - _prensor_to_field_map(p.get_children(), node.size)) - raise ValueError("Must be a root prensor") + Raises: + ValueError: if the root of the prensor is not a RootNodeTensor. 
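+
+    Example (a sketch, not part of the library; assumes `pren` is a root
+    prensor with a leaf field "foo"):
+      st = prensor_to_structured_tensor(pren)
+      values = st.field_value("foo")  # a 2-D tf.RaggedTensor of leaf values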
+ """ + node = p.node + if isinstance(node, prensor.RootNodeTensor): + return _root_node_to_structured_tensor( + _prensor_to_field_map(p.get_children(), node.size) + ) + raise ValueError("Must be a root prensor") def _root_node_to_structured_tensor( - fields: Mapping[path.Step, prensor.Prensor] + fields: Mapping[path.Step, prensor.Prensor], ) -> structured_tensor.StructuredTensor: - """Convert a map of prensors to a structured tensor.""" - return structured_tensor.StructuredTensor.from_fields( - fields=fields, shape=tf.TensorShape([None])) + """Convert a map of prensors to a structured tensor.""" + return structured_tensor.StructuredTensor.from_fields( + fields=fields, shape=tf.TensorShape([None]) + ) def _prensor_to_structured_tensor_helper( p: prensor.Prensor, nrows: tf.Tensor ) -> Union[tf.RaggedTensor, structured_tensor.StructuredTensor]: - """Convert a prensor to a structured tensor with a certain number of rows.""" - node = p.node - if isinstance(node, prensor.LeafNodeTensor): - return _leaf_node_to_ragged_tensor(node, nrows) - assert isinstance(node, prensor.ChildNodeTensor) - return _child_node_to_structured_tensor( - node, _prensor_to_field_map(p.get_children(), node.size), nrows) + """Convert a prensor to a structured tensor with a certain number of rows.""" + node = p.node + if isinstance(node, prensor.LeafNodeTensor): + return _leaf_node_to_ragged_tensor(node, nrows) + assert isinstance(node, prensor.ChildNodeTensor) + return _child_node_to_structured_tensor( + node, _prensor_to_field_map(p.get_children(), node.size), nrows + ) def _prensor_to_field_map( - p_fields: Mapping[path.Step, prensor.Prensor], - nrows: tf.Tensor) -> Mapping[path.Step, structured_tensor.StructuredTensor]: - """Convert a map of prensors to map of structured tensors.""" - result = {} - for step, child in p_fields.items(): - try: - result[step] = _prensor_to_structured_tensor_helper(child, nrows) - except ValueError as err: - raise ValueError(f"Error in field: {step}") from err - return result + p_fields: Mapping[path.Step, prensor.Prensor], nrows: tf.Tensor +) -> Mapping[path.Step, structured_tensor.StructuredTensor]: + """Convert a map of prensors to map of structured tensors.""" + result = {} + for step, child in p_fields.items(): + try: + result[step] = _prensor_to_structured_tensor_helper(child, nrows) + except ValueError as err: + raise ValueError(f"Error in field: {step}") from err + return result def _child_node_to_structured_tensor( - node: prensor.ChildNodeTensor, fields: Mapping[path.Step, prensor.Prensor], - nrows: tf.Tensor) -> structured_tensor.StructuredTensor: - """Convert a map of prensors to map of structured tensors.""" - st = structured_tensor.StructuredTensor.from_fields( - fields=fields, shape=tf.TensorShape([None]), nrows=node.size) - row_partition = RowPartition.from_value_rowids( - value_rowids=node.parent_index, nrows=nrows) - return st.partition_outer_dimension(row_partition) - - -def _leaf_node_to_ragged_tensor(node: prensor.LeafNodeTensor, - nrows: tf.Tensor) -> tf.RaggedTensor: - """Converts a LeafNodeTensor to a 2D ragged tensor.""" - return tf.RaggedTensor.from_value_rowids( - values=node.values, value_rowids=node.parent_index, nrows=nrows) + node: prensor.ChildNodeTensor, + fields: Mapping[path.Step, prensor.Prensor], + nrows: tf.Tensor, +) -> structured_tensor.StructuredTensor: + """Convert a map of prensors to map of structured tensors.""" + st = structured_tensor.StructuredTensor.from_fields( + fields=fields, shape=tf.TensorShape([None]), nrows=node.size + ) + 
row_partition = RowPartition.from_value_rowids( + value_rowids=node.parent_index, nrows=nrows + ) + return st.partition_outer_dimension(row_partition) + + +def _leaf_node_to_ragged_tensor( + node: prensor.LeafNodeTensor, nrows: tf.Tensor +) -> tf.RaggedTensor: + """Converts a LeafNodeTensor to a 2D ragged tensor.""" + return tf.RaggedTensor.from_value_rowids( + values=node.values, value_rowids=node.parent_index, nrows=nrows + ) diff --git a/struct2tensor/prensor_to_structured_tensor_test.py b/struct2tensor/prensor_to_structured_tensor_test.py index 52f0069..264da8d 100644 --- a/struct2tensor/prensor_to_structured_tensor_test.py +++ b/struct2tensor/prensor_to_structured_tensor_test.py @@ -23,84 +23,94 @@ # @test_util.run_all_in_graph_and_eager_modes class PrensorToStructuredTensorTest(tf.test.TestCase): + def test_simple_prensor(self): + pren = prensor_test_util.create_simple_prensor() + st = prensor_to_structured_tensor.prensor_to_structured_tensor(pren) + self.assertAllEqual(st.field_value("foo"), [[9], [8], [7]]) + self.assertAllEqual(st.field_value("foorepeated"), [[9], [8, 7], [6]]) - def test_simple_prensor(self): - pren = prensor_test_util.create_simple_prensor() - st = prensor_to_structured_tensor.prensor_to_structured_tensor(pren) - self.assertAllEqual(st.field_value("foo"), [[9], [8], [7]]) - self.assertAllEqual(st.field_value("foorepeated"), [[9], [8, 7], [6]]) + def test_nested_prensor(self): + """Tests on a deep expression.""" + pren = prensor_test_util.create_nested_prensor() + st = prensor_to_structured_tensor.prensor_to_structured_tensor(pren) + self.assertAllEqual( + st.field_value(["doc", "bar"]), [[[b"a"]], [[b"b", b"c"], [b"d"]], []] + ) + self.assertAllEqual( + st.field_value(["doc", "keep_me"]), [[[False]], [[True], []], []] + ) + self.assertAllEqual( + st.field_value(["user", "friends"]), + [[[b"a"]], [[b"b", b"c"], [b"d"]], [[b"e"]]], + ) - def test_nested_prensor(self): - """Tests on a deep expression.""" - pren = prensor_test_util.create_nested_prensor() - st = prensor_to_structured_tensor.prensor_to_structured_tensor(pren) - self.assertAllEqual( - st.field_value(["doc", "bar"]), [[[b"a"]], [[b"b", b"c"], [b"d"]], []]) - self.assertAllEqual( - st.field_value(["doc", "keep_me"]), [[[False]], [[True], []], []]) - self.assertAllEqual( - st.field_value(["user", "friends"]), - [[[b"a"]], [[b"b", b"c"], [b"d"]], [[b"e"]]]) + def test_big_prensor(self): + """Test the big prensor. - def test_big_prensor(self): - """Test the big prensor. 
+ a prensor expression representing: + {foo:9, foorepeated:[9], doc:[{bar:["a"], keep_me:False}], + user:[{friends:["a"]}]} + {foo:8, foorepeated:[8,7], + doc:[{bar:["b","c"],keep_me:True},{bar:["d"]}], + user:[{friends:["b", "c"]},{friends:["d"]}],} + {foo:7, foorepeated:[6], user:[friends:["e"]]} + """ + pren = prensor_test_util.create_big_prensor() + st = prensor_to_structured_tensor.prensor_to_structured_tensor(pren) + self.assertAllEqual(st.field_value("foo"), [[9], [8], [7]]) + self.assertAllEqual(st.field_value("foorepeated"), [[9], [8, 7], [6]]) + self.assertAllEqual( + st.field_value(["doc", "keep_me"]), [[[False]], [[True], []], []] + ) + self.assertAllEqual( + st.field_value(["user", "friends"]), + [[[b"a"]], [[b"b", b"c"], [b"d"]], [[b"e"]]], + ) + self.assertAllEqual( + st.field_value(["doc", "bar"]), [[[b"a"]], [[b"b", b"c"], [b"d"]], []] + ) - a prensor expression representing: - {foo:9, foorepeated:[9], doc:[{bar:["a"], keep_me:False}], - user:[{friends:["a"]}]} - {foo:8, foorepeated:[8,7], - doc:[{bar:["b","c"],keep_me:True},{bar:["d"]}], - user:[{friends:["b", "c"]},{friends:["d"]}],} - {foo:7, foorepeated:[6], user:[friends:["e"]]} - """ - pren = prensor_test_util.create_big_prensor() - st = prensor_to_structured_tensor.prensor_to_structured_tensor(pren) - self.assertAllEqual(st.field_value("foo"), [[9], [8], [7]]) - self.assertAllEqual(st.field_value("foorepeated"), [[9], [8, 7], [6]]) - self.assertAllEqual( - st.field_value(["doc", "keep_me"]), [[[False]], [[True], []], []]) - self.assertAllEqual( - st.field_value(["user", "friends"]), - [[[b"a"]], [[b"b", b"c"], [b"d"]], [[b"e"]]]) - self.assertAllEqual( - st.field_value(["doc", "bar"]), [[[b"a"]], [[b"b", b"c"], [b"d"]], []]) + def test_deep_prensor(self): + """Test a prensor with three layers: root, event, and doc. - def test_deep_prensor(self): - """Test a prensor with three layers: root, event, and doc. 
+ a prensor expression representing: + {foo:9, foorepeated:[9], user:[{friends:["a"]}], + event:{doc:[{bar:["a"], keep_me:False}]}} + {foo:8, foorepeated:[8,7], + event:{doc:[{bar:["b","c"], keep_me:True},{bar:["d"]}]}, + user:[{friends:["b", "c"]}, {friends:["d"]}]} + {foo:7, foorepeated:[6], user:[friends:["e"]], event:{}} + """ + pren = prensor_test_util.create_deep_prensor() + st = prensor_to_structured_tensor.prensor_to_structured_tensor(pren) + self.assertAllEqual(st.field_value("foo"), [[9], [8], [7]]) + self.assertAllEqual(st.field_value(["foorepeated"]), [[9], [8, 7], [6]]) + self.assertAllEqual( + st.field_value(["user", "friends"]), + [[[b"a"]], [[b"b", b"c"], [b"d"]], [[b"e"]]], + ) + self.assertAllEqual( + st.field_value(["event", "doc", "bar"]), + [[[[b"a"]]], [[[b"b", b"c"], [b"d"]]], [[]]], + ) + self.assertAllEqual( + st.field_value(["event", "doc", "keep_me"]), + [[[[False]]], [[[True], []]], [[]]], + ) - a prensor expression representing: - {foo:9, foorepeated:[9], user:[{friends:["a"]}], - event:{doc:[{bar:["a"], keep_me:False}]}} - {foo:8, foorepeated:[8,7], - event:{doc:[{bar:["b","c"], keep_me:True},{bar:["d"]}]}, - user:[{friends:["b", "c"]}, {friends:["d"]}]} - {foo:7, foorepeated:[6], user:[friends:["e"]], event:{}} - """ - pren = prensor_test_util.create_deep_prensor() - st = prensor_to_structured_tensor.prensor_to_structured_tensor(pren) - self.assertAllEqual(st.field_value("foo"), [[9], [8], [7]]) - self.assertAllEqual(st.field_value(["foorepeated"]), [[9], [8, 7], [6]]) - self.assertAllEqual( - st.field_value(["user", "friends"]), - [[[b"a"]], [[b"b", b"c"], [b"d"]], [[b"e"]]]) - self.assertAllEqual( - st.field_value(["event", "doc", "bar"]), - [[[[b"a"]]], [[[b"b", b"c"], [b"d"]]], [[]]]) - self.assertAllEqual( - st.field_value(["event", "doc", "keep_me"]), - [[[[False]]], [[[True], []]], [[]]]) + def test_non_root_prensor(self): + child_prensor = prensor.create_prensor_from_root_and_children( + prensor_test_util.create_child_node([0, 0, 1, 3, 7], True), {} + ) + with self.assertRaisesRegex(ValueError, "Must be a root prensor"): + prensor_to_structured_tensor.prensor_to_structured_tensor(child_prensor) - def test_non_root_prensor(self): - child_prensor = prensor.create_prensor_from_root_and_children( - prensor_test_util.create_child_node([0, 0, 1, 3, 7], True), {}) - with self.assertRaisesRegex(ValueError, "Must be a root prensor"): - prensor_to_structured_tensor.prensor_to_structured_tensor(child_prensor) - - def test_e2e_proto(self): - """Integration test for parsing protobufs.""" - serialized = tf.constant([ - text_format.Merge( - """ + def test_e2e_proto(self): + """Integration test for parsing protobufs.""" + serialized = tf.constant( + [ + text_format.Merge( + """ session_info { session_duration_sec: 1.0 session_feature: "foo" @@ -122,15 +132,19 @@ def test_e2e_proto(self): number_of_views: 3 } } - """, test_pb2.Session()).SerializeToString() - ]) - expr = proto.create_expression_from_proto(serialized, - test_pb2.Session().DESCRIPTOR) - [p] = calculate.calculate_prensors([expr]) - print(p) - st = prensor_to_structured_tensor.prensor_to_structured_tensor(p) - print(st) + """, + test_pb2.Session(), + ).SerializeToString() + ] + ) + expr = proto.create_expression_from_proto( + serialized, test_pb2.Session().DESCRIPTOR + ) + [p] = calculate.calculate_prensors([expr]) + print(p) + st = prensor_to_structured_tensor.prensor_to_structured_tensor(p) + print(st) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git 
a/struct2tensor/prensor_util.py b/struct2tensor/prensor_util.py index 6e2abb6..453173f 100644 --- a/struct2tensor/prensor_util.py +++ b/struct2tensor/prensor_util.py @@ -22,7 +22,7 @@ import tensorflow as tf from tensorflow.python.util import ( - deprecation, # pylint: disable=g-direct-tensorflow-import + deprecation, # pylint: disable=g-direct-tensorflow-import ) from struct2tensor import calculate_options, path, prensor @@ -32,80 +32,76 @@ def get_sparse_tensor( t: prensor.Prensor, p: path.Path, - options: calculate_options.Options = calculate_options.get_default_options( - ) + options: calculate_options.Options = calculate_options.get_default_options(), ) -> tf.SparseTensor: - """Gets a sparse tensor for path p. + """Gets a sparse tensor for path p. - Note that any optional fields are not registered as dimensions, as they can't - be represented in a sparse tensor. + Note that any optional fields are not registered as dimensions, as they can't + be represented in a sparse tensor. - Args: - t: The Prensor to extract tensors from. - p: The path to a leaf node in `t`. - options: Currently unused. + Args: + t: The Prensor to extract tensors from. + p: The path to a leaf node in `t`. + options: Currently unused. - Returns: - A sparse tensor containing values of the leaf node, preserving the - structure along the path. Raises an error if the path is not found. - """ - return t.get_sparse_tensor(p, options) + Returns: + A sparse tensor containing values of the leaf node, preserving the + structure along the path. Raises an error if the path is not found. + """ + return t.get_sparse_tensor(p, options) @deprecation.deprecated(None, "Use the Prensor class method instead.") def get_sparse_tensors( t: prensor.Prensor, - options: calculate_options.Options = calculate_options.get_default_options( - ) + options: calculate_options.Options = calculate_options.get_default_options(), ) -> Mapping[path.Path, tf.SparseTensor]: - """Gets sparse tensors for all the leaves of the prensor expression. + """Gets sparse tensors for all the leaves of the prensor expression. - Args: - t: The Prensor to extract tensors from. - options: Currently unused. + Args: + t: The Prensor to extract tensors from. + options: Currently unused. - Returns: - A map from paths to sparse tensors. - """ - return t.get_sparse_tensors(options) + Returns: + A map from paths to sparse tensors. + """ + return t.get_sparse_tensors(options) @deprecation.deprecated(None, "Use the Prensor class method instead.") def get_ragged_tensor( t: prensor.Prensor, p: path.Path, - options: calculate_options.Options = calculate_options.get_default_options( - ) + options: calculate_options.Options = calculate_options.get_default_options(), ) -> tf.RaggedTensor: - """Get a ragged tensor for a path. + """Get a ragged tensor for a path. - All steps are represented in the ragged tensor. + All steps are represented in the ragged tensor. - Args: - t: The Prensor to extract tensors from. - p: the path to a leaf node in `t`. - options: used to pass options for calculating ragged tensors. + Args: + t: The Prensor to extract tensors from. + p: the path to a leaf node in `t`. + options: used to pass options for calculating ragged tensors. - Returns: - A ragged tensor containing values of the leaf node, preserving the - structure along the path. Raises an error if the path is not found. - """ - return t.get_ragged_tensor(p, options) + Returns: + A ragged tensor containing values of the leaf node, preserving the + structure along the path. 
Raises an error if the path is not found. + """ + return t.get_ragged_tensor(p, options) @deprecation.deprecated(None, "Use the Prensor class method instead.") def get_ragged_tensors( t: prensor.Prensor, - options: calculate_options.Options = calculate_options.get_default_options( - ) + options: calculate_options.Options = calculate_options.get_default_options(), ) -> Mapping[path.Path, tf.RaggedTensor]: - """Gets ragged tensors for all the leaves of the prensor expression. + """Gets ragged tensors for all the leaves of the prensor expression. - Args: - t: The Prensor to extract tensors from. - options: used to pass options for calculating ragged tensors. + Args: + t: The Prensor to extract tensors from. + options: used to pass options for calculating ragged tensors. - Returns: - A map from paths to ragged tensors. - """ - return t.get_ragged_tensors(options) + Returns: + A map from paths to ragged tensors. + """ + return t.get_ragged_tensors(options) diff --git a/struct2tensor/prensor_value.py b/struct2tensor/prensor_value.py index eb57ef8..6db1953 100644 --- a/struct2tensor/prensor_value.py +++ b/struct2tensor/prensor_value.py @@ -30,271 +30,279 @@ import numpy as np import tensorflow as tf from tensorflow.python.client import ( - session as session_lib, # pylint: disable=g-direct-tensorflow-import + session as session_lib, # pylint: disable=g-direct-tensorflow-import ) from struct2tensor import path, prensor class RootNodeValue: - """The value of the root.""" + """The value of the root.""" - __slots__ = ["_size"] + __slots__ = ["_size"] - def __init__(self, size: np.int64): - """Creates a root node. + def __init__(self, size: np.int64): + """Creates a root node. - Args: - size: how many root objects there are. - """ - self._size = size + Args: + size: how many root objects there are. + """ + self._size = size - @property - def size(self): - return self._size + @property + def size(self): + return self._size - @property - def is_repeated(self): - return True + @property + def is_repeated(self): + return True - def schema_string(self): - return "repeated" + def schema_string(self): + return "repeated" - def data_string(self): - return f"size: {self._size}" + def data_string(self): + return f"size: {self._size}" - def __str__(self): - return "RootNode" + def __str__(self): + return "RootNode" class ChildNodeValue: - """The value of an intermediate node.""" + """The value of an intermediate node.""" - __slots__ = ["_parent_index", "_is_repeated"] + __slots__ = ["_parent_index", "_is_repeated"] - def __init__(self, parent_index: np.ndarray, is_repeated: bool): - """Creates a child node. + def __init__(self, parent_index: np.ndarray, is_repeated: bool): + """Creates a child node. - Args: - parent_index: a 1-D int64 ndarray where parent_index[i] represents the - parent index of the ith child. - is_repeated: a bool indicating if there can be more than one child per - parent. - """ - self._parent_index = parent_index - self._is_repeated = is_repeated + Args: + parent_index: a 1-D int64 ndarray where parent_index[i] represents the + parent index of the ith child. + is_repeated: a bool indicating if there can be more than one child per + parent. + """ + self._parent_index = parent_index + self._is_repeated = is_repeated - @property - def size(self): - """Returns the size, as if this was the root prensor. + @property + def size(self): + """Returns the size, as if this was the root prensor. - Returns: - A 1-D ndarray of size 1. 
- """ - return tf.shape(self.parent_index, out_type=tf.int64) + Returns: + A 1-D ndarray of size 1. + """ + return tf.shape(self.parent_index, out_type=tf.int64) - @property - def parent_index(self): - return self._parent_index + @property + def parent_index(self): + return self._parent_index - @property - def is_repeated(self): - return self._is_repeated + @property + def is_repeated(self): + return self._is_repeated - def schema_string(self) -> str: - return "repeated" if self.is_repeated else "optional" + def schema_string(self) -> str: + return "repeated" if self.is_repeated else "optional" - def data_string(self): - return f"parent_index: {self._parent_index}" + def data_string(self): + return f"parent_index: {self._parent_index}" - def __str__(self): - return f"ChildNode {self.schema_string()} {self.data_string()}" + def __str__(self): + return f"ChildNode {self.schema_string()} {self.data_string()}" class LeafNodeValue: - """The value of a leaf node.""" + """The value of a leaf node.""" - __slots__ = ["_parent_index", "_values", "_is_repeated"] + __slots__ = ["_parent_index", "_values", "_is_repeated"] - def __init__(self, parent_index: np.ndarray, values: np.ndarray, - is_repeated: bool): - """Creates a leaf node. + def __init__(self, parent_index: np.ndarray, values: np.ndarray, is_repeated: bool): + """Creates a leaf node. - Args: - parent_index: a 1-D int64 ndarray where parent_index[i] represents the - parent index of values[i] - values: a 1-D ndarray of equal length to parent_index. - is_repeated: a bool indicating if there can be more than one child per - parent. - """ - self._parent_index = parent_index - self._values = values - self._is_repeated = is_repeated + Args: + parent_index: a 1-D int64 ndarray where parent_index[i] represents the + parent index of values[i] + values: a 1-D ndarray of equal length to parent_index. + is_repeated: a bool indicating if there can be more than one child per + parent. + """ + self._parent_index = parent_index + self._values = values + self._is_repeated = is_repeated - @property - def parent_index(self): - return self._parent_index + @property + def parent_index(self): + return self._parent_index - @property - def is_repeated(self): - return self._is_repeated + @property + def is_repeated(self): + return self._is_repeated - @property - def values(self): - return self._values + @property + def values(self): + return self._values - def data_string(self): - return f"parent_index: {self._parent_index} values: {self._values}" + def data_string(self): + return f"parent_index: {self._parent_index} values: {self._values}" - def schema_string(self) -> str: - return "{} {}".format("repeated" if self.is_repeated else "optional", - str(self.values.dtype)) + def schema_string(self) -> str: + return "{} {}".format( + "repeated" if self.is_repeated else "optional", str(self.values.dtype) + ) - def __str__(self): - return "{} {}".format("repeated" if self.is_repeated else "optional", - str(self.values.dtype)) + def __str__(self): + return "{} {}".format( + "repeated" if self.is_repeated else "optional", str(self.values.dtype) + ) NodeValue = Union[RootNodeValue, ChildNodeValue, LeafNodeValue] # pylint: disable=invalid-name class PrensorValue: - """A tree of NodeValue objects.""" - - __slots__ = ["_node", "_children"] - - def __init__(self, node: NodeValue, - children: "collections.OrderedDict[path.Step, PrensorValue]"): - """Construct a PrensorValue. - - Do not call directly, instead call materialize(...) below. 
- - Args: - node: the NodeValue of the root. - children: a map from edge to subtree. - """ - self._node = node - self._children = children - - # TODO(martinz): This could be Value. - @property - def node(self) -> NodeValue: - """The node of the root of the subtree.""" - return self._node - - def get_child(self, field_name: path.Step) -> Optional["PrensorValue"]: - """Gets the child at field_name.""" - return self._children.get(field_name) - - def is_leaf(self) -> bool: - """True iff the node value is a LeafNodeValue.""" - return isinstance(self._node, LeafNodeValue) - - def get_child_or_error(self, field_name: path.Step) -> "PrensorValue": - """Gets the child at field_name.""" - result = self._children.get(field_name) - if result is not None: - return result - raise ValueError(f"Field not found: {str(field_name)}") - - def get_descendant(self, p: path.Path) -> Optional["PrensorValue"]: - """Finds the descendant at the path.""" - result = self - for field_name in p.field_list: - result = result.get_child(field_name) - if result is None: - return None - return result - - def get_descendant_or_error(self, p: path.Path) -> "PrensorValue": - """Finds the descendant at the path.""" - result = self.get_descendant(p) - if result is None: - raise ValueError(f"Missing path: {str(p)}") - return result - - def get_children(self) -> Mapping[path.Step, "PrensorValue"]: - """A map from field name to subtree.""" - return self._children - - def get_descendants(self) -> Mapping[path.Path, "PrensorValue"]: - """A map from paths to all subtrees.""" - result = {path.Path([]): self} - for k, v in self._children.items(): - subtree_descendants = v.get_descendants() - for k2, v2 in subtree_descendants.items(): - result[path.Path([k]).concat(k2)] = v2 - return result - - def field_names(self) -> FrozenSet[path.Step]: - """Returns the field names of the children.""" - return frozenset(self._children.keys()) - - def _string_helper(self, field_name: str) -> Sequence[str]: - """Helper for __str__ that outputs a list of lines.""" - result = [ - f"{self.node.schema_string()} {str(field_name)} {self.node.data_string()}" - ] - for k, v in self._children.items(): - recursive = v._string_helper(k) # pylint: disable=protected-access - result.extend([f" {x}" for x in recursive]) - return result - - def _schema_string_helper(self, field_name: str) -> Sequence[str]: - """Helper for __str__ that outputs a list of lines.""" - result = [f"{self.node.schema_string()} {str(field_name)}"] - for k, v in self._children.items(): - recursive = v._string_helper(k) # pylint: disable=protected-access - result.extend([f" {x}" for x in recursive]) - return result - - def schema_string(self): - """Returns a string representing the schema of the Prensor.""" - return "\n".join(self._schema_string_helper("")) - - def __str__(self): - """Returns a string representing the schema of the Prensor.""" - return "\n".join(self._string_helper("")) + """A tree of NodeValue objects.""" + + __slots__ = ["_node", "_children"] + + def __init__( + self, + node: NodeValue, + children: "collections.OrderedDict[path.Step, PrensorValue]", + ): + """Construct a PrensorValue. + + Do not call directly, instead call materialize(...) below. + + Args: + node: the NodeValue of the root. + children: a map from edge to subtree. + """ + self._node = node + self._children = children + + # TODO(martinz): This could be Value. 
+    @property
+    def node(self) -> NodeValue:
+        """The node of the root of the subtree."""
+        return self._node
+
+    def get_child(self, field_name: path.Step) -> Optional["PrensorValue"]:
+        """Gets the child at field_name."""
+        return self._children.get(field_name)
+
+    def is_leaf(self) -> bool:
+        """True iff the node value is a LeafNodeValue."""
+        return isinstance(self._node, LeafNodeValue)
+
+    def get_child_or_error(self, field_name: path.Step) -> "PrensorValue":
+        """Gets the child at field_name."""
+        result = self._children.get(field_name)
+        if result is not None:
+            return result
+        raise ValueError(f"Field not found: {str(field_name)}")
+
+    def get_descendant(self, p: path.Path) -> Optional["PrensorValue"]:
+        """Finds the descendant at the path."""
+        result = self
+        for field_name in p.field_list:
+            result = result.get_child(field_name)
+            if result is None:
+                return None
+        return result
+
+    def get_descendant_or_error(self, p: path.Path) -> "PrensorValue":
+        """Finds the descendant at the path."""
+        result = self.get_descendant(p)
+        if result is None:
+            raise ValueError(f"Missing path: {str(p)}")
+        return result
+
+    def get_children(self) -> Mapping[path.Step, "PrensorValue"]:
+        """A map from field name to subtree."""
+        return self._children
+
+    def get_descendants(self) -> Mapping[path.Path, "PrensorValue"]:
+        """A map from paths to all subtrees."""
+        result = {path.Path([]): self}
+        for k, v in self._children.items():
+            subtree_descendants = v.get_descendants()
+            for k2, v2 in subtree_descendants.items():
+                result[path.Path([k]).concat(k2)] = v2
+        return result
+
+    def field_names(self) -> FrozenSet[path.Step]:
+        """Returns the field names of the children."""
+        return frozenset(self._children.keys())
+
+    def _string_helper(self, field_name: str) -> Sequence[str]:
+        """Helper for __str__ that outputs a list of lines."""
+        result = [
+            f"{self.node.schema_string()} {str(field_name)} {self.node.data_string()}"
+        ]
+        for k, v in self._children.items():
+            recursive = v._string_helper(k)  # pylint: disable=protected-access
+            result.extend([f"  {x}" for x in recursive])
+        return result
+
+    def _schema_string_helper(self, field_name: str) -> Sequence[str]:
+        """Helper for schema_string that outputs a list of lines."""
+        result = [f"{self.node.schema_string()} {str(field_name)}"]
+        for k, v in self._children.items():
+            recursive = v._string_helper(k)  # pylint: disable=protected-access
+            result.extend([f"  {x}" for x in recursive])
+        return result
+
+    def schema_string(self):
+        """Returns a string representing the schema of the Prensor."""
+        return "\n".join(self._schema_string_helper(""))
+
+    def __str__(self):
+        """Returns a string representing the Prensor, including its data."""
+        return "\n".join(self._string_helper(""))


 def _prensor_value_from_type_spec_and_component_values(
     prensor_type_spec: prensor._PrensorTypeSpec,
-    component_values: Iterator[Union[int, np.ndarray]]) -> PrensorValue:
-  """Creates a PrensorValue from a _PrensorTypeSpec and components."""
-  # pylint: disable=protected-access
-  if prensor_type_spec._node_type == prensor_type_spec._NodeType.ROOT:
-    size = next(component_values)
-    assert isinstance(size, (int, np.integer)), type(size)
-    node = RootNodeValue(np.int64(size))
-  elif prensor_type_spec._node_type == prensor_type_spec._NodeType.CHILD:
-    node = ChildNodeValue(next(component_values),
-                          prensor_type_spec._is_repeated)
-  else:
-    parent_index = next(component_values)
-    values = next(component_values)
-    node = LeafNodeValue(parent_index, values, prensor_type_spec._is_repeated)
-
step_to_child = collections.OrderedDict() - for step, child_spec in prensor_type_spec._children_specs: - step_to_child[step] = _prensor_value_from_type_spec_and_component_values( - child_spec, component_values) - return PrensorValue(node, step_to_child) + component_values: Iterator[Union[int, np.ndarray]], +) -> PrensorValue: + """Creates a PrensorValue from a _PrensorTypeSpec and components.""" + # pylint: disable=protected-access + if prensor_type_spec._node_type == prensor_type_spec._NodeType.ROOT: + size = next(component_values) + assert isinstance(size, (int, np.integer)), type(size) + node = RootNodeValue(np.int64(size)) + elif prensor_type_spec._node_type == prensor_type_spec._NodeType.CHILD: + node = ChildNodeValue(next(component_values), prensor_type_spec._is_repeated) + else: + parent_index = next(component_values) + values = next(component_values) + node = LeafNodeValue(parent_index, values, prensor_type_spec._is_repeated) + + step_to_child = collections.OrderedDict() + for step, child_spec in prensor_type_spec._children_specs: + step_to_child[step] = _prensor_value_from_type_spec_and_component_values( + child_spec, component_values + ) + return PrensorValue(node, step_to_child) def _prensor_value_fetch(prensor_tree: prensor.Prensor): - """Fetch function for PrensorValue. See the document in session_lib.""" - # pylint: disable=protected-access - type_spec = prensor_tree._type_spec - components = type_spec._to_components(prensor_tree) - def _construct_prensor_value(component_values): - return _prensor_value_from_type_spec_and_component_values( - type_spec, iter(component_values)) + """Fetch function for PrensorValue. See the document in session_lib.""" + # pylint: disable=protected-access + type_spec = prensor_tree._type_spec + components = type_spec._to_components(prensor_tree) - return components, _construct_prensor_value + def _construct_prensor_value(component_values): + return _prensor_value_from_type_spec_and_component_values( + type_spec, iter(component_values) + ) + + return components, _construct_prensor_value session_lib.register_session_run_conversion_functions( prensor.Prensor, _prensor_value_fetch, feed_function=None, - feed_function_for_partial_run=None) + feed_function_for_partial_run=None, +) diff --git a/struct2tensor/prensor_value_test.py b/struct2tensor/prensor_value_test.py index 2b6aae3..693a220 100644 --- a/struct2tensor/prensor_value_test.py +++ b/struct2tensor/prensor_value_test.py @@ -16,71 +16,86 @@ import tensorflow as tf from struct2tensor import ( - path, - prensor, # pylint: disable=unused-import + path, + prensor, # pylint: disable=unused-import ) from struct2tensor.test import prensor_test_util # This is a V1 test. The prensor value is not meaningful for V2. 
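+# Roughly what sess.run(pren) does in the tests below (a sketch;
+# _prensor_value_fetch is the private fetch conversion registered in
+# prensor_value.py):
+#
+#   components, rebuild = _prensor_value_fetch(pren)
+#   pv = rebuild(sess.run(components))  # -> a PrensorValue tree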
class PrensorValueTest(tf.test.TestCase):
+    def test_session_run(self):
+        """Tests session.run on a deep prensor."""
+        if tf.executing_eagerly():
+            return

-  def test_session_run(self):
-    """Tests get_sparse_tensors on a deep expression."""
-    if tf.executing_eagerly():
-      return
+        with self.cached_session(use_gpu=False) as sess:
+            pren = prensor_test_util.create_nested_prensor()
+            mat = sess.run(pren)
+            self.assertAllEqual(
+                mat.get_descendant_or_error(path.Path(["doc", "bar"])).node.values,
+                [b"a", b"b", b"c", b"d"],
+            )
+            self.assertAllEqual(
+                mat.get_descendant_or_error(
+                    path.Path(["doc", "bar"])
+                ).node.parent_index,
+                [0, 1, 1, 2],
+            )
+            self.assertAllEqual(
+                mat.get_descendant_or_error(path.Path(["doc"])).node.parent_index,
+                [0, 1, 1],
+            )
+            self.assertAllEqual(mat.get_descendant_or_error(path.Path([])).node.size, 3)
+            self.assertAllEqual(
+                mat.get_descendant_or_error(
+                    path.Path(["doc", "keep_me"])
+                ).node.parent_index,
+                [0, 1],
+            )
+            self.assertAllEqual(
+                mat.get_descendant_or_error(path.Path(["doc", "keep_me"])).node.values,
+                [False, True],
+            )

-    with self.cached_session(use_gpu=False) as sess:
-      pren = prensor_test_util.create_nested_prensor()
-      mat = sess.run(pren)
-      self.assertAllEqual(
-          mat.get_descendant_or_error(path.Path(["doc", "bar"])).node.values,
-          [b"a", b"b", b"c", b"d"])
-      self.assertAllEqual(
-          mat.get_descendant_or_error(path.Path(
-              ["doc", "bar"])).node.parent_index, [0, 1, 1, 2])
-      self.assertAllEqual(
-          mat.get_descendant_or_error(path.Path(["doc"])).node.parent_index,
-          [0, 1, 1])
-      self.assertAllEqual(
-          mat.get_descendant_or_error(path.Path([])).node.size, 3)
-      self.assertAllEqual(
-          mat.get_descendant_or_error(path.Path(
-              ["doc", "keep_me"])).node.parent_index, [0, 1])
-      self.assertAllEqual(
-          mat.get_descendant_or_error(path.Path(
-              ["doc", "keep_me"])).node.values, [False, True])
+    def test_children_order(self):
+        # Different evaluations of the same prensor object should result in the same
+        # prensor value objects.
+        if tf.executing_eagerly():
+            return

-  def test_children_order(self):
-    # Different evaluations of the same prensor object should result in the same
-    # prensor value objects.
-    if tf.executing_eagerly():
-      return
-    def _check_children(pv):
-      self.assertEqual(sorted(pv.get_children().keys()),
-                       list(pv.get_children().keys()))
-      for child in pv.get_children().values():
-        _check_children(child)
+        def _check_children(pv):
+            self.assertEqual(
+                sorted(pv.get_children().keys()), list(pv.get_children().keys())
+            )
+            for child in pv.get_children().values():
+                _check_children(child)

-    with self.cached_session(use_gpu=False) as sess:
-      p = prensor_test_util.create_nested_prensor()
-      _check_children(sess.run(p))
+        with self.cached_session(use_gpu=False) as sess:
+            p = prensor_test_util.create_nested_prensor()
+            _check_children(sess.run(p))

-    with self.cached_session(use_gpu=False) as sess:
-      p = prensor.create_prensor_from_descendant_nodes({
-          path.Path([]): prensor_test_util.create_root_node(1),
-          path.Path(["d"]): prensor_test_util.create_optional_leaf_node(
-              [0], [True]),
-          path.Path(["c"]): prensor_test_util.create_optional_leaf_node(
-              [0], [True]),
-          path.Path(["b"]): prensor_test_util.create_optional_leaf_node(
-              [0], [True]),
-          path.Path(["a"]): prensor_test_util.create_optional_leaf_node(
-              [0], [True]),
-      })
-      pv = sess.run(p)
-      self.assertEqual(["a", "b", "c", "d"], list(pv.get_children().keys()))
+        with self.cached_session(use_gpu=False) as sess:
+            p = prensor.create_prensor_from_descendant_nodes(
+                {
+                    path.Path([]): prensor_test_util.create_root_node(1),
+                    path.Path(["d"]): prensor_test_util.create_optional_leaf_node(
+                        [0], [True]
+                    ),
+                    path.Path(["c"]): prensor_test_util.create_optional_leaf_node(
+                        [0], [True]
+                    ),
+                    path.Path(["b"]): prensor_test_util.create_optional_leaf_node(
+                        [0], [True]
+                    ),
+                    path.Path(["a"]): prensor_test_util.create_optional_leaf_node(
+                        [0], [True]
+                    ),
+                }
+            )
+            pv = sess.run(p)
+            self.assertEqual(["a", "b", "c", "d"], list(pv.get_children().keys()))


if __name__ == "__main__":
-  tf.test.main()
+    tf.test.main()
diff --git a/struct2tensor/struct2tensor_module_test.py b/struct2tensor/struct2tensor_module_test.py
index 7fda60f..3a4a170 100644
--- a/struct2tensor/struct2tensor_module_test.py
+++ b/struct2tensor/struct2tensor_module_test.py
@@ -19,54 +19,66 @@

class Struct2tensorModuleTest(absltest.TestCase):
+    def test_importing_struct2tensor_modules(self):
+        """This tests that the exposed packages in root __init__.py are found."""
+        # pylint: disable=pointless-statement

-  def test_importing_struct2tensor_modules(self):
-    """This tests that the exposed packages in root __init__.py are found."""
-    # pylint: disable=pointless-statement

+        # calculate APIs
+        s2t.calculate_prensors
+        s2t.calculate_prensors_with_graph
+        s2t.get_default_options
+        s2t.get_options_with_minimal_checks
+        s2t.calculate_prensors_with_source_paths

-    # calculate APIs
-    s2t.calculate_prensors
-    s2t.calculate_prensors_with_graph
-    s2t.get_default_options
-    s2t.get_options_with_minimal_checks
-    s2t.calculate_prensors_with_source_paths

+        # expression APIs
+        s2t.create_expression_from_prensor
+        s2t.create_expression_from_file_descriptor_set
+        s2t.create_expression_from_proto
+        s2t.Expression

-    # expression APIs
-    s2t.create_expression_from_prensor
-    s2t.create_expression_from_file_descriptor_set
-    s2t.create_expression_from_proto
-    s2t.Expression

+        # path API
+        s2t.create_path
+        s2t.Path
+        s2t.Step

-    # path API
-    s2t.create_path
-    s2t.Path
-    s2t.Step

+        # prensor APIs
+        s2t.ChildNodeTensor
+        s2t.LeafNodeTensor
+        s2t.NodeTensor
+        s2t.Prensor
+        s2t.RootNodeTensor
+        s2t.create_prensor_from_descendant_nodes
+        s2t.create_prensor_from_root_and_children
+        s2t.prensor_value
+
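+        # Each bare attribute access above is a smoke check: a missing
+        # symbol raises AttributeError and fails the test, so no explicit
+        # assertions are needed.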
+        # pylint: enable=pointless-statement

-    # prensor APIs
-    s2t.ChildNodeTensor
-    s2t.LeafNodeTensor
-    s2t.NodeTensor
-    s2t.Prensor
-    s2t.RootNodeTensor
-    s2t.create_prensor_from_descendant_nodes
-    s2t.create_prensor_from_root_and_children
-    s2t.prensor_value
-    # pylint: enable=pointless-statement

+    def test_importing_expression_impl_modules(self):
+        """This tests that the expression_impl/__init__.py imports are found."""
+        from struct2tensor import expression_impl  # pylint: disable=g-import-not-at-top

-  def test_importing_expression_impl_modules(self):
-    """This tests that the expression_impl/__init__.py imports are found."""
-    from struct2tensor import expression_impl  # pylint: disable=g-import-not-at-top
+        modules = [
+            "apply_schema",
+            "broadcast",
+            "depth_limit",
+            "filter_expression",
+            "index",
+            "map_prensor",
+            "map_prensor_to_prensor",
+            "map_values",
+            "parquet",
+            "placeholder",
+            "project",
+            "promote",
+            "promote_and_broadcast",
+            "proto",
+            "reroot",
+            "size",
+            "slice_expression",
+        ]

-    modules = [
-        'apply_schema', 'broadcast', 'depth_limit', 'filter_expression',
-        'index', 'map_prensor', 'map_prensor_to_prensor', 'map_values',
-        'parquet', 'placeholder', 'project', 'promote', 'promote_and_broadcast',
-        'proto', 'reroot', 'size', 'slice_expression'
-    ]
+        for module in modules:
+            getattr(expression_impl, module)

-    for module in modules:
-      getattr(expression_impl, module)
-
-if __name__ == '__main__':
-  absltest.main()

+if __name__ == "__main__":
+    absltest.main()
diff --git a/struct2tensor/structured_tensor_to_prensor.py b/struct2tensor/structured_tensor_to_prensor.py
index d607e7a..f7dfa0b 100644
--- a/struct2tensor/structured_tensor_to_prensor.py
+++ b/struct2tensor/structured_tensor_to_prensor.py
@@ -43,263 +43,278 @@
import tensorflow.compat.v2 as tf

from tensorflow.python.ops.ragged.row_partition import (
-  RowPartition,  # pylint: disable=g-direct-tensorflow-import
+    RowPartition,  # pylint: disable=g-direct-tensorflow-import
)
from tensorflow.python.ops.structured import (
-  structured_tensor,  # pylint: disable=g-direct-tensorflow-import
+    structured_tensor,  # pylint: disable=g-direct-tensorflow-import
)

from struct2tensor import path, prensor


def structured_tensor_to_prensor(
-    st: structured_tensor.StructuredTensor,
-    default_field_name: path.Step = "data") -> prensor.Prensor:
-  """Converts a structured tensor to a prensor.
-
-  Certain rank information must be known. For more details about the
-  transformation, see the notes above.
-
-  Args:
-    st: the structured tensor to convert.
-    default_field_name: the name to use when there is an unnamed dimension.
-
-  Returns:
-    a logically equivalent Prensor.
-
-  Raises:
-    ValueError: if there is an issue with the structured tensor.
-  """
-  row_partitions = st.row_partitions
-  if len(row_partitions) >= 1:
-    child_prensor = _structured_tensor_to_child_prensor(st, default_field_name)
-    return prensor.create_prensor_from_root_and_children(
-        prensor.RootNodeTensor((st).nrows()),
-        {default_field_name: child_prensor})
-  if st.rank == 1:
-    return prensor.create_prensor_from_root_and_children(
-        prensor.RootNodeTensor((st).nrows()),
-        _structured_tensor_prensor_map(st, default_field_name))
-  # st is a scalar StructuredTensor.
-  return structured_tensor_to_prensor(_expand_dims(st, 0), default_field_name)
+    st: structured_tensor.StructuredTensor, default_field_name: path.Step = "data"
+) -> prensor.Prensor:
+    """Converts a structured tensor to a prensor.
+
+    Certain rank information must be known. For more details about the
+    transformation, see the notes above.
+
+    Args:
+      st: the structured tensor to convert.
+      default_field_name: the name to use when there is an unnamed dimension.
+
+    Returns:
+      a logically equivalent Prensor.
+
+    Raises:
+      ValueError: if there is an issue with the structured tensor.
+    """
+    row_partitions = st.row_partitions
+    if len(row_partitions) >= 1:
+        child_prensor = _structured_tensor_to_child_prensor(st, default_field_name)
+        return prensor.create_prensor_from_root_and_children(
+            prensor.RootNodeTensor(st.nrows()), {default_field_name: child_prensor}
+        )
+    if st.rank == 1:
+        return prensor.create_prensor_from_root_and_children(
+            prensor.RootNodeTensor(st.nrows()),
+            _structured_tensor_prensor_map(st, default_field_name),
+        )
+    # st is a scalar StructuredTensor.
+    return structured_tensor_to_prensor(_expand_dims(st, 0), default_field_name)


def _structured_tensor_prensor_map(
-    st: structured_tensor.StructuredTensor,
-    default_field_name: path.Step) -> Mapping[path.Step, prensor.Prensor]:
-  """Creates a map of fields, to put in a child or root prensor."""
-  return {
-      k: _structured_tensor_field_to_prensor(
-          st.field_value(k), default_field_name) for k in st.field_names()
-  }
+    st: structured_tensor.StructuredTensor, default_field_name: path.Step
+) -> Mapping[path.Step, prensor.Prensor]:
+    """Creates a map of fields, to put in a child or root prensor."""
+    return {
+        k: _structured_tensor_field_to_prensor(st.field_value(k), default_field_name)
+        for k in st.field_names()
+    }


# _expand_dims requires special treatment for scalar StructuredTensors, because
# it is not adding a partition dimension. Therefore, we have to expand the
# dimension of each field explicitly.
def _expand_dims_scalar(st: structured_tensor.StructuredTensor):
-  """_expand_dims for a scalar structured tensor."""
-  new_shape = tf.constant([1], dtype=tf.int64)
-  new_fields = {k: _expand_dims(st.field_value(k), 0) for k in st.field_names()}
-  return structured_tensor.StructuredTensor.from_fields(
-      new_fields, shape=new_shape)
+    """_expand_dims for a scalar structured tensor."""
+    new_shape = tf.constant([1], dtype=tf.int64)
+    new_fields = {k: _expand_dims(st.field_value(k), 0) for k in st.field_names()}
+    return structured_tensor.StructuredTensor.from_fields(new_fields, shape=new_shape)


def _expand_dims_nonnegative_axis(axis, rank):
-  """Get the nonnegative axis according to the rules of tf.expand_dims."""
-  # Implementation note: equivalent to get_positive_axis(axis, rank + 1)
-  if axis < 0:
-    new_axis = (1 + rank) + axis
-    if new_axis < 0:
-      # Note: this is unreachable in the current code.
-      raise ValueError("Axis out of range: " + str(axis))
-    return new_axis
-  if axis > rank:
-    # Note: this is unreachable in the current code.
-    raise ValueError("Axis larger than rank: " + str(axis) + " > " + str(rank))
-  return axis
+    """Get the nonnegative axis according to the rules of tf.expand_dims."""
+    # Implementation note: equivalent to get_positive_axis(axis, rank + 1)
+    if axis < 0:
+        new_axis = (1 + rank) + axis
+        if new_axis < 0:
+            # Note: this is unreachable in the current code.
+            raise ValueError("Axis out of range: " + str(axis))
+        return new_axis
+    if axis > rank:
+        # Note: this is unreachable in the current code.
+        raise ValueError("Axis larger than rank: " + str(axis) + " > " + str(rank))
+    return axis


def _expand_dims(st, axis):
-  """tf.expand_dims, but works on StructuredTensor too.
-
-  Note: the implementation does not work if axis > 1, and will throw a
-  ValueError.
-
-  Args:
-    st: a Tensor, RaggedTensor, or StructuredTensor.
-    axis: the axis to insert a dimension before.
-
-  Returns:
-    a tensor with one more dimension (see tf.expand_dims).
-
-  Raises:
-    ValueError:
-      if the axis is not valid.
-  """
-  if not isinstance(st, structured_tensor.StructuredTensor):
-    return tf.expand_dims(st, axis)
-  nn_axis = _expand_dims_nonnegative_axis(axis, st.rank)
-  if st.rank == 0:
-    return _expand_dims_scalar(st)
-  if nn_axis == 0:
-    # Here, we can add a dimension 1 at the front.
-    nrows = st.nrows()
-    return st.partition_outer_dimension(
-        RowPartition.from_uniform_row_length(nrows, nrows))
-  if nn_axis == 1:
-    # Again, by partitioning the first dimension into vectors of length 1,
-    # we can solve this problem.
-    nrows = st.nrows()
-    return st.partition_outer_dimension(
-        RowPartition.from_uniform_row_length(
-            tf.constant(1, dtype=nrows.dtype), nrows))
-  # Note: this is unreachable in the current code.
-  raise ValueError("Unimplemented: non-negative axis > 1 for _expand_dims")
+    """tf.expand_dims, but works on StructuredTensor too.
+
+    Note: the implementation does not work if axis > 1, and will throw a
+    ValueError.
+
+    Args:
+      st: a Tensor, RaggedTensor, or StructuredTensor.
+      axis: the axis to insert a dimension before.
+
+    Returns:
+      a tensor with one more dimension (see tf.expand_dims).
+
+    Raises:
+      ValueError:
+        if the axis is not valid.
+    """
+    if not isinstance(st, structured_tensor.StructuredTensor):
+        return tf.expand_dims(st, axis)
+    nn_axis = _expand_dims_nonnegative_axis(axis, st.rank)
+    if st.rank == 0:
+        return _expand_dims_scalar(st)
+    if nn_axis == 0:
+        # Here, we can add a dimension 1 at the front.
+        nrows = st.nrows()
+        return st.partition_outer_dimension(
+            RowPartition.from_uniform_row_length(nrows, nrows)
+        )
+    if nn_axis == 1:
+        # Again, by partitioning the first dimension into vectors of length 1,
+        # we can solve this problem.
+        nrows = st.nrows()
+        return st.partition_outer_dimension(
+            RowPartition.from_uniform_row_length(
+                tf.constant(1, dtype=nrows.dtype), nrows
+            )
+        )
+    # Note: this is unreachable in the current code.
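+    # Shape sketch for the supported axes (hypothetical rank-1 input with
+    # nrows() == 3): axis=0 partitions into one row of length 3, giving
+    # shape [1, 3]; axis=1 partitions into three rows of uniform length 1,
+    # giving shape [3, 1].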
+    raise ValueError("Unimplemented: non-negative axis > 1 for _expand_dims")


def _structured_tensor_field_to_prensor(
-    field_value: Union[structured_tensor.StructuredTensor, tf.RaggedTensor,
-                       tf.Tensor],
-    default_field_name: path.Step) -> prensor.Prensor:
-  """Creates a ChildNodeTensor from a field in a structured tensor."""
-  if isinstance(field_value, structured_tensor.StructuredTensor):
-    return _structured_tensor_to_child_prensor(field_value, default_field_name)
-  return _to_leaf_prensor(field_value, default_field_name)
+    field_value: Union[structured_tensor.StructuredTensor, tf.RaggedTensor, tf.Tensor],
+    default_field_name: path.Step,
+) -> prensor.Prensor:
+    """Creates a ChildNodeTensor from a field in a structured tensor."""
+    if isinstance(field_value, structured_tensor.StructuredTensor):
+        return _structured_tensor_to_child_prensor(field_value, default_field_name)
+    return _to_leaf_prensor(field_value, default_field_name)


def _row_partition_to_child_node_tensor(row_partition: RowPartition):
-  """Creates a ChildNodeTensor from a RowPartition."""
-  return prensor.ChildNodeTensor(
-      tf.cast(row_partition.value_rowids(), tf.int64),
-      is_repeated=True)
-
-
-def _one_child_prensor(row_partition: RowPartition,
-                       child_prensor: prensor.Prensor,
-                       default_field_name: path.Step) -> prensor.Prensor:
-  """Creates a prensor with a ChildNodeTensor at the root with one child."""
-  child_node_tensor = _row_partition_to_child_node_tensor(row_partition)
-  return prensor.create_prensor_from_root_and_children(
-      child_node_tensor, {default_field_name: child_prensor})
+    """Creates a ChildNodeTensor from a RowPartition."""
+    return prensor.ChildNodeTensor(
+        tf.cast(row_partition.value_rowids(), tf.int64), is_repeated=True
+    )
+
+
+def _one_child_prensor(
+    row_partition: RowPartition,
+    child_prensor: prensor.Prensor,
+    default_field_name: path.Step,
+) -> prensor.Prensor:
+    """Creates a prensor with a ChildNodeTensor at the root with one child."""
+    child_node_tensor = _row_partition_to_child_node_tensor(row_partition)
+    return prensor.create_prensor_from_root_and_children(
+        child_node_tensor, {default_field_name: child_prensor}
+    )


def _structured_tensor_to_child_prensor(
-    st: structured_tensor.StructuredTensor,
-    default_field_name: path.Step) -> prensor.Prensor:
-  """Creates a prensor with a ChildNodeTensor at the root."""
-  row_partitions = st.row_partitions
-  if len(row_partitions) == 1:
-    child_st = st.merge_dims(0, 1)
-    row_partition = row_partitions[0]
-    return prensor.create_prensor_from_root_and_children(
-        _row_partition_to_child_node_tensor(row_partition),
-        _structured_tensor_prensor_map(child_st, default_field_name))
-  if len(row_partitions) > 1:
-    row_partition = row_partitions[0]
-    child_st = st.merge_dims(0, 1)
+    st: structured_tensor.StructuredTensor, default_field_name: path.Step
+) -> prensor.Prensor:
+    """Creates a prensor with a ChildNodeTensor at the root."""
+    row_partitions = st.row_partitions
+    if len(row_partitions) == 1:
+        child_st = st.merge_dims(0, 1)
+        row_partition = row_partitions[0]
+        return prensor.create_prensor_from_root_and_children(
+            _row_partition_to_child_node_tensor(row_partition),
+            _structured_tensor_prensor_map(child_st, default_field_name),
+        )
+    if len(row_partitions) > 1:
+        row_partition = row_partitions[0]
+        child_st = st.merge_dims(0, 1)
+        return _one_child_prensor(
+            row_partition,
+            _structured_tensor_to_child_prensor(child_st, default_field_name),
+            default_field_name,
+        )
+    # This requires us to transform the scalar to a vector.
+    # The fields could be scalars or vectors.
+    # We need _expand_dims(...) to make this work.
+    return _structured_tensor_to_child_prensor(_expand_dims(st, 1), default_field_name)
+
+
+def _to_leaf_prensor_helper(
+    rt: tf.RaggedTensor, default_field_name: path.Step
+) -> prensor.Prensor:
+    """Converts a fully partitioned ragged tensor to a leaf prensor.
+
+    It is assumed that this is a fully partitioned ragged tensor. Specifically,
+    the values at the end are a vector, not a 2D tensor.
+
+    Args:
+      rt: a fully partitioned ragged tensor (see
+        _fully_partitioned_ragged_tensor).
+      default_field_name: a path.Step for unnamed dimensions.
+
+    Returns:
+      a prensor, with a leaf as the root node.
+    """
+    row_partition = rt._row_partition  # pylint: disable=protected-access
+    if rt.ragged_rank == 1:
+        values = rt.values
+        leaf = prensor.LeafNodeTensor(row_partition.value_rowids(), values, True)
+        return prensor.create_prensor_from_root_and_children(leaf, {})
-    return _one_child_prensor(
-        row_partition,
-        _structured_tensor_to_child_prensor(child_st, default_field_name),
-        default_field_name)
-  # This requires us to transform the scalar to a vector.
-  # The fields could be scalars or vectors.
-  # We need _expand_dims(...) to make this work.
-  return _structured_tensor_to_child_prensor(
-      _expand_dims(st, 1), default_field_name)
-
-
-def _to_leaf_prensor_helper(rt: tf.RaggedTensor,
-                            default_field_name: path.Step) -> prensor.Prensor:
-  """Converts a fully partitioned ragged tensor to a leaf prensor.
-
-  It is assumed that this is a fully partitioned ragged tensor. Specifically,
-  the values at the end are a vector, not a 2D tensor.
-
-  Args:
-    rt: a fully partitioned ragged tensor (see
-      _fully_partitioned_ragged_tensor).
-    default_field_name: a path.Step for unnamed dimensions.
-
-  Returns:
-    a prensor, with a leaf as the root node.
-  """
-  row_partition = rt._row_partition  # pylint: disable=protected-access
-  if rt.ragged_rank == 1:
-    values = rt.values
-    leaf = prensor.LeafNodeTensor(row_partition.value_rowids(), values, True)
-    return prensor.create_prensor_from_root_and_children(leaf, {})
-  return _one_child_prensor(
-      row_partition, _to_leaf_prensor_helper(rt.values, default_field_name),
-      default_field_name)
+    return _one_child_prensor(
+        row_partition,
+        _to_leaf_prensor_helper(rt.values, default_field_name),
+        default_field_name,
+    )


def _partition_if_not_vector(values: tf.Tensor, dtype: tf.dtypes.DType):
-  """Creates a fully partitioned ragged tensor from a multidimensional tensor.
-
-  If the tensor is 1D, then it is unchanged.
-
-  Args:
-    values: the tensor to be transformed
-    dtype: the type of the row splits.
-
-  Returns:
-    A 1D tensor or a ragged tensor.
-
-  Raises:
-    ValueError: if the shape cannot be statically determined or is a scalar.
-  """
-  values_shape = values.shape
-  assert values_shape is not None
-  values_rank = values_shape.rank
-  # values cannot have an unknown rank in a RaggedTensor field
-  # in a StructuredTensor.
-  assert values_rank is not None
-  if values_rank == 1:
-    return values
-  # This cannot happen inside a ragged tensor.
-  assert values_rank > 0
-  return tf.RaggedTensor.from_tensor(
-      values, ragged_rank=values_rank - 1, row_splits_dtype=dtype)
-
-
-def _fully_partitioned_ragged_tensor(rt: Union[tf.RaggedTensor, tf.Tensor],
-                                     dtype=tf.dtypes.int64):
-  """Creates a fully partitioned ragged tensor from a tensor or a ragged tensor.
-
-  If given a tensor, it must be at least two-dimensional.
-
-  A fully partitioned ragged tensor is:
-  1. A ragged tensor.
-  2. The final values are a vector.
-
-  Args:
-    rt: input to coerce from RaggedTensor or Tensor. Must be at least 2D.
-    dtype: requested dtype for partitions: tf.int64 or tf.int32.
-
-  Returns:
-    A ragged tensor where the flat values are a 1D tensor.
-
-  Raises:
-    ValueError: if the tensor is 0D or 1D.
-  """
-  if isinstance(rt, tf.RaggedTensor):
-    rt = rt.with_row_splits_dtype(dtype)
-    flattened_values = _partition_if_not_vector(rt.flat_values, dtype=dtype)
-    return rt.with_flat_values(flattened_values)
-  rt_shape = rt.shape
-  assert rt_shape is not None
-  rt_rank = rt_shape.rank
-  assert rt_rank is not None
-  if rt_rank < 2:
-    # Increase the rank if it is a scalar.
-    return _fully_partitioned_ragged_tensor(tf.expand_dims(rt, -1))
-  return tf.RaggedTensor.from_tensor(
-      rt, ragged_rank=rt_rank - 1, row_splits_dtype=dtype)
-
-
-def _to_leaf_prensor(rt: Union[tf.RaggedTensor, tf.Tensor],
-                     default_field_name: path.Step) -> prensor.Prensor:
-  """Creates a leaf tensor from a ragged tensor or tensor."""
-  return _to_leaf_prensor_helper(
-      _fully_partitioned_ragged_tensor(rt), default_field_name)
+    """Creates a fully partitioned ragged tensor from a multidimensional tensor.
+
+    If the tensor is 1D, then it is unchanged.
+
+    Args:
+      values: the tensor to be transformed
+      dtype: the type of the row splits.
+
+    Returns:
+      A 1D tensor or a ragged tensor.
+
+    Raises:
+      ValueError: if the shape cannot be statically determined or is a scalar.
+    """
+    values_shape = values.shape
+    assert values_shape is not None
+    values_rank = values_shape.rank
+    # values cannot have an unknown rank in a RaggedTensor field
+    # in a StructuredTensor.
+    assert values_rank is not None
+    if values_rank == 1:
+        return values
+    # This cannot happen inside a ragged tensor.
+    assert values_rank > 0
+    return tf.RaggedTensor.from_tensor(
+        values, ragged_rank=values_rank - 1, row_splits_dtype=dtype
+    )
+
+
+def _fully_partitioned_ragged_tensor(
+    rt: Union[tf.RaggedTensor, tf.Tensor], dtype=tf.dtypes.int64
+):
+    """Creates a fully partitioned ragged tensor from a tensor or a ragged tensor.
+
+    If given a tensor, it must be at least two-dimensional.
+
+    A fully partitioned ragged tensor is:
+    1. A ragged tensor.
+    2. The final values are a vector.
+
+    Args:
+      rt: input to coerce from RaggedTensor or Tensor. Must be at least 2D.
+      dtype: requested dtype for partitions: tf.int64 or tf.int32.
+
+    Returns:
+      A ragged tensor where the flat values are a 1D tensor.
+
+    Raises:
+      ValueError: if the tensor is 0D or 1D.
+    """
+    if isinstance(rt, tf.RaggedTensor):
+        rt = rt.with_row_splits_dtype(dtype)
+        flattened_values = _partition_if_not_vector(rt.flat_values, dtype=dtype)
+        return rt.with_flat_values(flattened_values)
+    rt_shape = rt.shape
+    assert rt_shape is not None
+    rt_rank = rt_shape.rank
+    assert rt_rank is not None
+    if rt_rank < 2:
+        # Increase the rank if it is a scalar.
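+        # (Hypothetical walk-through: a scalar tf.constant(5) expands to
+        # shape [1], recurses, expands again to [[5]], and comes back as a
+        # ragged tensor with row_splits [0, 1] and flat values [5].)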
+ return _fully_partitioned_ragged_tensor(tf.expand_dims(rt, -1)) + return tf.RaggedTensor.from_tensor( + rt, ragged_rank=rt_rank - 1, row_splits_dtype=dtype + ) + + +def _to_leaf_prensor( + rt: Union[tf.RaggedTensor, tf.Tensor], default_field_name: path.Step +) -> prensor.Prensor: + """Creates a leaf tensor from a ragged tensor or tensor.""" + return _to_leaf_prensor_helper( + _fully_partitioned_ragged_tensor(rt), default_field_name + ) diff --git a/struct2tensor/structured_tensor_to_prensor_test.py b/struct2tensor/structured_tensor_to_prensor_test.py index f84206c..efca2dd 100644 --- a/struct2tensor/structured_tensor_to_prensor_test.py +++ b/struct2tensor/structured_tensor_to_prensor_test.py @@ -17,564 +17,567 @@ import tensorflow.compat.v2 as tf from absl.testing import parameterized from tensorflow.python.ops.ragged.row_partition import ( - RowPartition, # pylint: disable=g-direct-tensorflow-import + RowPartition, # pylint: disable=g-direct-tensorflow-import ) from tensorflow.python.ops.structured import ( - structured_tensor, # pylint: disable=g-direct-tensorflow-import + structured_tensor, # pylint: disable=g-direct-tensorflow-import ) from struct2tensor import path, structured_tensor_to_prensor def _make_structured_tensor(shape, fields): - return structured_tensor.StructuredTensor.from_fields( - fields=fields, shape=shape) + return structured_tensor.StructuredTensor.from_fields(fields=fields, shape=shape) class StructuredTensorToPrensorTest(tf.test.TestCase, parameterized.TestCase): - - @parameterized.named_parameters([ - { - # Structure:{},{} - "testcase_name": "no_fields", - "shape": [2], - "fields": { - }, - }, - { - # Structure:[] - "testcase_name": "field_of_zeroes", - "shape": [0], - "fields": { - "t": np.zeros([0, 5]), - }, - "path_to_check": "t", - "values": [], - "parent_indices": [], - }, - { - # Structure:[][unknown ragged] - "testcase_name": "empty_2d", - "shape": [2, 2], - "fields": { - }, - }, - { - # Structure:{"r":[1,2]},{"r":[3]} - "testcase_name": "one_normal_field", - "shape": [2], - "fields": { - "r": tf.ragged.constant([[1, 2], [3]]), - }, - "path_to_check": "r", - "parent_indices": [0, 0, 1], - "values": [1, 2, 3], - }, - { - # Structure:{"r":[1, 2]},{"r":[3, 4]} - "testcase_name": "one_dense_field", - "shape": [2], - "fields": { - "r": tf.constant([[1, 2], [3, 4]]), - }, - "path_to_check": "r", - "parent_indices": [0, 0, 1, 1], - "values": [1, 2, 3, 4], - }, - { - # Structure:{"zzz":[{"data":[1, 2]}, {"data":[3]}]}, - # {"zzz":[{"data":[4, 5, 6]}, {"data":[7]}, - # {"data":[8, 9]}]}, - "testcase_name": "field_with_two_dimensions_first_dimension", - "shape": [2], - "fields": { - "zzz": tf.ragged.constant( - [[[1, 2], [3]], [[4, 5, 6], [7], [8, 9]]]), - - }, - "path_to_check": "zzz", - "root_size": 2, - "parent_indices": [0, 0, 1, 1, 1], - }, - { - # Structure:{"r":[{"data":[1, 2]}, {"data":[3]}]}, - # {"r":[{"data":[4, 5, 6]}, {"data":[7]}, {"data":[8, 9]}]}, - "testcase_name": "test_field_with_two_dimensions_second_dimension", - "shape": [2], - "fields": { - "r": tf.ragged.constant( - [[[1, 2], [3]], [[4, 5, 6], [7], [8, 9]]]), - - }, - "path_to_check": "r.data", - "parent_indices": [0, 0, 1, 2, 2, 2, 3, 4, 4], - }, - { - # Structure:{"foo":[{"bar":[1, 2]}, {"bar":[3]}]}, - # {"foo":[{"bar":[4, 5, 6]}, {"bar":[7]}, {"bar":[8, 9]}]}, - "testcase_name": "StructTensor_within_StructTensor", - "shape": [2], - "fields": { - "foo": - _make_structured_tensor([2, None], - {"bar": - tf.ragged.constant( - [[[1, 2], [3]], - [[4, 5, 6], [7], [8, 9]]])}), - }, - 
"path_to_check": "foo", - "root_size": 2, - "parent_indices": [0, 0, 1, 1, 1], - }, - { - "testcase_name": "multiple_normal_fields", - "shape": [2], - "fields": { - "a": tf.ragged.constant([[1, 2], [3]]), - "b": tf.ragged.constant([[4, 5], [6]]), - }, - "path_to_check": "a", - "parent_indices": [0, 0, 1], - }, - { - "testcase_name": "multiple_fields_with_two_dimensions", - "shape": [2], - "fields": { - "a": tf.ragged.constant( - [[[1, 2], [3]], [[4, 5, 6], [7], [8, 9]]]), - "b": tf.ragged.constant( - [[[1], []], [[2, 3], [4, 5, 6], [7, 8]]]), - }, - "path_to_check": "b.data", - "values": [1, 2, 3, 4, 5, 6, 7, 8], - }, - { - "testcase_name": "one_empty_field", - "shape": [0], - "fields": { - "t": [], - }, - "path_to_check": "t", - "values": [], - "parent_indices": [], - }, - { - "testcase_name": "shallow_field", - "shape": [2], - "fields": {"t": [5, 8]}, - "path_to_check": "t", - "values": [5, 8], - "parent_indices": [0, 1], - }, - { - "testcase_name": "shallow_message", - "shape": [2], - "fields": {"t": structured_tensor.StructuredTensor.from_pyval([{"a": [7]}, {"a": [8]}])}, - "path_to_check": "t", - "parent_indices": [0, 1], - }, - { - "testcase_name": "zero_list_tensor", - "shape": [], - "fields": { - }, - }, - { - # A scalar value for a scalar structured tensor. - "testcase_name": "one_scalar", - "shape": [], - "fields": { - "t": 3 - }, - "path_to_check": "t", - "parent_indices": [0], - "values": [3], - }, - { - # A repeated value for a scalar structured tensor. - "testcase_name": "scalar_repeated", - "shape": [], - "fields": { - "t": [3, 4, 5] - }, - "path_to_check": "t", - "parent_indices": [0, 0, 0], - "values": [3, 4, 5], - }, - - ]) # pyformat: disable - def testField(self, - shape, - fields, - path_to_check=None, - parent_indices=None, - root_size=None, - values=None): - struct = _make_structured_tensor(shape, fields) - prensor = structured_tensor_to_prensor.structured_tensor_to_prensor(struct) - if root_size is not None: - self.assertAllEqual(root_size, prensor.node.size) - if path_to_check is not None: - my_path = path.create_path(path_to_check) - descendant = prensor.get_descendant(my_path) - self.assertIsNotNone(descendant) - my_node = descendant.node - if parent_indices is not None: - self.assertAllEqual(my_node.parent_index, parent_indices) - if values is not None: - self.assertAllEqual(my_node.values, values) - - def testEmpty2DRagged(self): - struct = structured_tensor.StructuredTensor.from_fields( - fields={}, - shape=[2, None], - row_partitions=[RowPartition.from_row_splits([0, 3, 5])]) - p = structured_tensor_to_prensor.structured_tensor_to_prensor(struct) - child_node = p.get_child("data").node - self.assertAllEqual(child_node.parent_index, [0, 0, 0, 1, 1]) - - def testStructuredTensorCreation(self): - rt = tf.RaggedTensor.from_value_rowids( - tf.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1]) - - struct = _make_structured_tensor([2], {"r": rt}) - p = structured_tensor_to_prensor.structured_tensor_to_prensor(struct) - rt_value = p.get_descendant(path.create_path("r.data")) - self.assertAllEqual(rt_value.node.parent_index, [0, 0, 1, 1, 2, 2]) - self.assertAllEqual(rt_value.node.values, [1, 2, 3, 4, 5, 6]) - - def testDeepStructuredTensor(self): - rt = tf.RaggedTensor.from_value_rowids( - tf.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1]) - - struct = _make_structured_tensor([2], {"r": rt}) - struct_2 = struct.partition_outer_dimension( - RowPartition.from_row_splits([0, 1, 2])) - - p = structured_tensor_to_prensor.structured_tensor_to_prensor(struct_2) - rt_value = 
p.get_descendant(path.create_path("data.r.data")) - self.assertAllEqual(rt_value.node.parent_index, [0, 0, 1, 1, 2, 2]) - self.assertAllEqual(rt_value.node.values, [1, 2, 3, 4, 5, 6]) - p_data = p.get_descendant(path.create_path("data")) - self.assertAllEqual(p_data.node.parent_index, [0, 1]) - p_data_r = p.get_descendant(path.create_path("data.r")) - self.assertAllEqual(p_data_r.node.parent_index, [0, 0, 1]) - - def testTwoDeepStructuredTensor(self): - rt = tf.RaggedTensor.from_value_rowids( - tf.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1]) - - struct = _make_structured_tensor([2], {"r": rt}) - struct_2 = struct.partition_outer_dimension( - RowPartition.from_row_splits([0, 1, 2])) - struct_3 = struct_2.partition_outer_dimension( - RowPartition.from_row_splits([0, 1, 2])) - p = structured_tensor_to_prensor.structured_tensor_to_prensor(struct_3) - rt_value = p.get_descendant(path.create_path("data.data.r.data")) - self.assertAllEqual(rt_value.node.parent_index, [0, 0, 1, 1, 2, 2]) - self.assertAllEqual(rt_value.node.values, [1, 2, 3, 4, 5, 6]) - - def testFullyPartitionedRaggedTensor2D(self): - values = [3, 1, 4, 1, 5, 9, 2, 6] - row_splits = [0, 4, 4, 7, 8, 8] - rt = tf.RaggedTensor.from_row_splits(values, row_splits=row_splits) - rt_result = structured_tensor_to_prensor._fully_partitioned_ragged_tensor( - rt) - self.assertAllEqual(rt_result.values, values) - self.assertAllEqual(rt_result.row_splits, row_splits) - - def testFullyPartitionedRaggedTensor3D(self): - values = [3, 1, 4, 1, 5, 9, 2, 6] - row_splits1 = [0, 2, 4] - row_splits2 = [0, 2, 5, 7, 8] - rt = tf.RaggedTensor.from_row_splits( - tf.RaggedTensor.from_row_splits(values=values, row_splits=row_splits2), - row_splits=row_splits1) - rt_result = structured_tensor_to_prensor._fully_partitioned_ragged_tensor( - rt) - self.assertAllEqual(rt_result.values.row_splits, row_splits2) - self.assertAllEqual(rt_result.row_splits, row_splits1) - self.assertAllEqual(rt_result.values.values, values) - - def testFullyPartitionedRaggedTensor2DTensor(self): - rt_result = structured_tensor_to_prensor._fully_partitioned_ragged_tensor( - tf.constant([[1, 2], [3, 4]])) - self.assertAllEqual(rt_result.values, [1, 2, 3, 4]) - self.assertAllEqual(rt_result.row_splits, [0, 2, 4]) - - def testFullyPartitionedRaggedTensor3DTensor(self): - rt_result = structured_tensor_to_prensor._fully_partitioned_ragged_tensor( - tf.constant([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])) - self.assertAllEqual(rt_result.values.values, - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) - self.assertAllEqual(rt_result.row_splits, [0, 2, 4]) - self.assertAllEqual(rt_result.values.row_splits, [0, 3, 6, 9, 12]) - - def testFullyPartitionedRaggedTensor3DRaggedRank1(self): - row_splits = [0, 2, 4] - rt_result = structured_tensor_to_prensor._fully_partitioned_ragged_tensor( - tf.RaggedTensor.from_row_splits( - tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), - row_splits=row_splits)) - self.assertAllEqual(rt_result.values.values, - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) - self.assertAllEqual(rt_result.row_splits, [0, 2, 4]) - self.assertAllEqual(rt_result.values.row_splits, [0, 3, 6, 9, 12]) - - def testFullyPartitionedRaggedTensor4DRaggedRank2(self): - row_splits1 = [0, 2] - row_splits2 = [0, 1, 2] - rt_result = structured_tensor_to_prensor._fully_partitioned_ragged_tensor( - tf.RaggedTensor.from_row_splits( + @parameterized.named_parameters( + [ + { + # Structure:{},{} + "testcase_name": "no_fields", + "shape": [2], + "fields": {}, + }, + { + # Structure:[] + 
"testcase_name": "field_of_zeroes", + "shape": [0], + "fields": { + "t": np.zeros([0, 5]), + }, + "path_to_check": "t", + "values": [], + "parent_indices": [], + }, + { + # Structure:[][unknown ragged] + "testcase_name": "empty_2d", + "shape": [2, 2], + "fields": {}, + }, + { + # Structure:{"r":[1,2]},{"r":[3]} + "testcase_name": "one_normal_field", + "shape": [2], + "fields": { + "r": tf.ragged.constant([[1, 2], [3]]), + }, + "path_to_check": "r", + "parent_indices": [0, 0, 1], + "values": [1, 2, 3], + }, + { + # Structure:{"r":[1, 2]},{"r":[3, 4]} + "testcase_name": "one_dense_field", + "shape": [2], + "fields": { + "r": tf.constant([[1, 2], [3, 4]]), + }, + "path_to_check": "r", + "parent_indices": [0, 0, 1, 1], + "values": [1, 2, 3, 4], + }, + { + # Structure:{"zzz":[{"data":[1, 2]}, {"data":[3]}]}, + # {"zzz":[{"data":[4, 5, 6]}, {"data":[7]}, + # {"data":[8, 9]}]}, + "testcase_name": "field_with_two_dimensions_first_dimension", + "shape": [2], + "fields": { + "zzz": tf.ragged.constant( + [[[1, 2], [3]], [[4, 5, 6], [7], [8, 9]]] + ), + }, + "path_to_check": "zzz", + "root_size": 2, + "parent_indices": [0, 0, 1, 1, 1], + }, + { + # Structure:{"r":[{"data":[1, 2]}, {"data":[3]}]}, + # {"r":[{"data":[4, 5, 6]}, {"data":[7]}, {"data":[8, 9]}]}, + "testcase_name": "test_field_with_two_dimensions_second_dimension", + "shape": [2], + "fields": { + "r": tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6], [7], [8, 9]]]), + }, + "path_to_check": "r.data", + "parent_indices": [0, 0, 1, 2, 2, 2, 3, 4, 4], + }, + { + # Structure:{"foo":[{"bar":[1, 2]}, {"bar":[3]}]}, + # {"foo":[{"bar":[4, 5, 6]}, {"bar":[7]}, {"bar":[8, 9]}]}, + "testcase_name": "StructTensor_within_StructTensor", + "shape": [2], + "fields": { + "foo": _make_structured_tensor( + [2, None], + { + "bar": tf.ragged.constant( + [[[1, 2], [3]], [[4, 5, 6], [7], [8, 9]]] + ) + }, + ), + }, + "path_to_check": "foo", + "root_size": 2, + "parent_indices": [0, 0, 1, 1, 1], + }, + { + "testcase_name": "multiple_normal_fields", + "shape": [2], + "fields": { + "a": tf.ragged.constant([[1, 2], [3]]), + "b": tf.ragged.constant([[4, 5], [6]]), + }, + "path_to_check": "a", + "parent_indices": [0, 0, 1], + }, + { + "testcase_name": "multiple_fields_with_two_dimensions", + "shape": [2], + "fields": { + "a": tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6], [7], [8, 9]]]), + "b": tf.ragged.constant([[[1], []], [[2, 3], [4, 5, 6], [7, 8]]]), + }, + "path_to_check": "b.data", + "values": [1, 2, 3, 4, 5, 6, 7, 8], + }, + { + "testcase_name": "one_empty_field", + "shape": [0], + "fields": { + "t": [], + }, + "path_to_check": "t", + "values": [], + "parent_indices": [], + }, + { + "testcase_name": "shallow_field", + "shape": [2], + "fields": {"t": [5, 8]}, + "path_to_check": "t", + "values": [5, 8], + "parent_indices": [0, 1], + }, + { + "testcase_name": "shallow_message", + "shape": [2], + "fields": { + "t": structured_tensor.StructuredTensor.from_pyval( + [{"a": [7]}, {"a": [8]}] + ) + }, + "path_to_check": "t", + "parent_indices": [0, 1], + }, + { + "testcase_name": "zero_list_tensor", + "shape": [], + "fields": {}, + }, + { + # A scalar value for a scalar structured tensor. + "testcase_name": "one_scalar", + "shape": [], + "fields": {"t": 3}, + "path_to_check": "t", + "parent_indices": [0], + "values": [3], + }, + { + # A repeated value for a scalar structured tensor. 
+ "testcase_name": "scalar_repeated", + "shape": [], + "fields": {"t": [3, 4, 5]}, + "path_to_check": "t", + "parent_indices": [0, 0, 0], + "values": [3, 4, 5], + }, + ] + ) # pyformat: disable + def testField( + self, + shape, + fields, + path_to_check=None, + parent_indices=None, + root_size=None, + values=None, + ): + struct = _make_structured_tensor(shape, fields) + prensor = structured_tensor_to_prensor.structured_tensor_to_prensor(struct) + if root_size is not None: + self.assertAllEqual(root_size, prensor.node.size) + if path_to_check is not None: + my_path = path.create_path(path_to_check) + descendant = prensor.get_descendant(my_path) + self.assertIsNotNone(descendant) + my_node = descendant.node + if parent_indices is not None: + self.assertAllEqual(my_node.parent_index, parent_indices) + if values is not None: + self.assertAllEqual(my_node.values, values) + + def testEmpty2DRagged(self): + struct = structured_tensor.StructuredTensor.from_fields( + fields={}, + shape=[2, None], + row_partitions=[RowPartition.from_row_splits([0, 3, 5])], + ) + p = structured_tensor_to_prensor.structured_tensor_to_prensor(struct) + child_node = p.get_child("data").node + self.assertAllEqual(child_node.parent_index, [0, 0, 0, 1, 1]) + + def testStructuredTensorCreation(self): + rt = tf.RaggedTensor.from_value_rowids( + tf.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1] + ) + + struct = _make_structured_tensor([2], {"r": rt}) + p = structured_tensor_to_prensor.structured_tensor_to_prensor(struct) + rt_value = p.get_descendant(path.create_path("r.data")) + self.assertAllEqual(rt_value.node.parent_index, [0, 0, 1, 1, 2, 2]) + self.assertAllEqual(rt_value.node.values, [1, 2, 3, 4, 5, 6]) + + def testDeepStructuredTensor(self): + rt = tf.RaggedTensor.from_value_rowids( + tf.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1] + ) + + struct = _make_structured_tensor([2], {"r": rt}) + struct_2 = struct.partition_outer_dimension( + RowPartition.from_row_splits([0, 1, 2]) + ) + + p = structured_tensor_to_prensor.structured_tensor_to_prensor(struct_2) + rt_value = p.get_descendant(path.create_path("data.r.data")) + self.assertAllEqual(rt_value.node.parent_index, [0, 0, 1, 1, 2, 2]) + self.assertAllEqual(rt_value.node.values, [1, 2, 3, 4, 5, 6]) + p_data = p.get_descendant(path.create_path("data")) + self.assertAllEqual(p_data.node.parent_index, [0, 1]) + p_data_r = p.get_descendant(path.create_path("data.r")) + self.assertAllEqual(p_data_r.node.parent_index, [0, 0, 1]) + + def testTwoDeepStructuredTensor(self): + rt = tf.RaggedTensor.from_value_rowids( + tf.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1] + ) + + struct = _make_structured_tensor([2], {"r": rt}) + struct_2 = struct.partition_outer_dimension( + RowPartition.from_row_splits([0, 1, 2]) + ) + struct_3 = struct_2.partition_outer_dimension( + RowPartition.from_row_splits([0, 1, 2]) + ) + p = structured_tensor_to_prensor.structured_tensor_to_prensor(struct_3) + rt_value = p.get_descendant(path.create_path("data.data.r.data")) + self.assertAllEqual(rt_value.node.parent_index, [0, 0, 1, 1, 2, 2]) + self.assertAllEqual(rt_value.node.values, [1, 2, 3, 4, 5, 6]) + + def testFullyPartitionedRaggedTensor2D(self): + values = [3, 1, 4, 1, 5, 9, 2, 6] + row_splits = [0, 4, 4, 7, 8, 8] + rt = tf.RaggedTensor.from_row_splits(values, row_splits=row_splits) + rt_result = structured_tensor_to_prensor._fully_partitioned_ragged_tensor(rt) + self.assertAllEqual(rt_result.values, values) + self.assertAllEqual(rt_result.row_splits, row_splits) + + def 
testFullyPartitionedRaggedTensor3D(self): + values = [3, 1, 4, 1, 5, 9, 2, 6] + row_splits1 = [0, 2, 4] + row_splits2 = [0, 2, 5, 7, 8] + rt = tf.RaggedTensor.from_row_splits( + tf.RaggedTensor.from_row_splits(values=values, row_splits=row_splits2), + row_splits=row_splits1, + ) + rt_result = structured_tensor_to_prensor._fully_partitioned_ragged_tensor(rt) + self.assertAllEqual(rt_result.values.row_splits, row_splits2) + self.assertAllEqual(rt_result.row_splits, row_splits1) + self.assertAllEqual(rt_result.values.values, values) + + def testFullyPartitionedRaggedTensor2DTensor(self): + rt_result = structured_tensor_to_prensor._fully_partitioned_ragged_tensor( + tf.constant([[1, 2], [3, 4]]) + ) + self.assertAllEqual(rt_result.values, [1, 2, 3, 4]) + self.assertAllEqual(rt_result.row_splits, [0, 2, 4]) + + def testFullyPartitionedRaggedTensor3DTensor(self): + rt_result = structured_tensor_to_prensor._fully_partitioned_ragged_tensor( + tf.constant([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]) + ) + self.assertAllEqual( + rt_result.values.values, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + ) + self.assertAllEqual(rt_result.row_splits, [0, 2, 4]) + self.assertAllEqual(rt_result.values.row_splits, [0, 3, 6, 9, 12]) + + def testFullyPartitionedRaggedTensor3DRaggedRank1(self): + row_splits = [0, 2, 4] + rt_result = structured_tensor_to_prensor._fully_partitioned_ragged_tensor( tf.RaggedTensor.from_row_splits( - tf.constant([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, - 12]]]), - row_splits=row_splits2), - row_splits=row_splits1)) - self.assertAllEqual(rt_result.row_splits, [0, 2]) - self.assertAllEqual(rt_result.values.row_splits, [0, 1, 2]) - self.assertAllEqual(rt_result.values.values.row_splits, [0, 2, 4]) - self.assertAllEqual(rt_result.values.values.values.row_splits, - [0, 3, 6, 9, 12]) - self.assertAllEqual(rt_result.values.values.values.values, - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) - - def testPartitionIfNotVector3D(self): - t = tf.constant([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]) - rt_result = structured_tensor_to_prensor._partition_if_not_vector( - t, tf.int64) - self.assertAllEqual(rt_result.row_splits, [0, 2, 4]) - self.assertAllEqual(rt_result.values.row_splits, [0, 3, 6, 9, 12]) - self.assertAllEqual(rt_result.values.values, - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + row_splits=row_splits, + ) + ) + self.assertAllEqual( + rt_result.values.values, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + ) + self.assertAllEqual(rt_result.row_splits, [0, 2, 4]) + self.assertAllEqual(rt_result.values.row_splits, [0, 3, 6, 9, 12]) + + def testFullyPartitionedRaggedTensor4DRaggedRank2(self): + row_splits1 = [0, 2] + row_splits2 = [0, 1, 2] + rt_result = structured_tensor_to_prensor._fully_partitioned_ragged_tensor( + tf.RaggedTensor.from_row_splits( + tf.RaggedTensor.from_row_splits( + tf.constant([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]), + row_splits=row_splits2, + ), + row_splits=row_splits1, + ) + ) + self.assertAllEqual(rt_result.row_splits, [0, 2]) + self.assertAllEqual(rt_result.values.row_splits, [0, 1, 2]) + self.assertAllEqual(rt_result.values.values.row_splits, [0, 2, 4]) + self.assertAllEqual(rt_result.values.values.values.row_splits, [0, 3, 6, 9, 12]) + self.assertAllEqual( + rt_result.values.values.values.values, + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + ) + + def testPartitionIfNotVector3D(self): + t = tf.constant([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]) + rt_result = 
structured_tensor_to_prensor._partition_if_not_vector(t, tf.int64) + self.assertAllEqual(rt_result.row_splits, [0, 2, 4]) + self.assertAllEqual(rt_result.values.row_splits, [0, 3, 6, 9, 12]) + self.assertAllEqual( + rt_result.values.values, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + ) class StructuredTensorExpandDims(tf.test.TestCase, parameterized.TestCase): - - @parameterized.named_parameters([ - { - # Structure:{},{} - "testcase_name": "no_fields", - "shape": [2], - "fields": { - }, - "axis": 0, - }, - { - # Structure:[] - "testcase_name": "field_of_zeroes", - "shape": [0], - "fields": { - "t": np.zeros([0, 5]), - }, - "axis": -1, - }, - { - # Structure:[][unknown ragged] - "testcase_name": "empty_2d", - "shape": [2, 2], - "fields": { - }, - "axis": 1, - }, - { - # Structure:{"r":[1,2]},{"r":[3]} - "testcase_name": "one_normal_field", - "shape": [2], - "fields": { - "r": tf.ragged.constant([[1, 2], [3]]), - }, - "axis": 1, - "target_field": "r", - "expected": [[[1, 2]], [[3]]], - }, - { - # Structure:{"r":[1, 2]},{"r":[3, 4]} - "testcase_name": "one_dense_field_axis_0", - "shape": [2], - "fields": { - "r": tf.constant([[1, 2], [3, 4]]), - }, - "axis": 0, - "target_field": "r", - "expected": [[[1, 2], [3, 4]]], - }, - { - # Structure:{"r":[1, 2]},{"r":[3, 4]} - "testcase_name": "one_dense_field_axis_1", - "shape": [2], - "fields": { - "r": tf.constant([[1, 2], [3, 4]]), - }, - "axis": 1, - "target_field": "r", - "expected": [[[1, 2]], [[3, 4]]] - }, - { - # Structure:{"zzz":[{"data":[1, 2]}, {"data":[3]}]}, - # {"zzz":[{"data":[4, 5, 6]}, {"data":[7]}, - # {"data":[8, 9]}]}, - "testcase_name": "field_with_two_dimensions_axis_0", - "shape": [2], - "fields": { - "zzz": tf.ragged.constant( - [[[1, 2], [3]], [[4, 5, 6], [7], [8, 9]]]), - - }, - "axis": 0, - "target_field": "zzz", - "expected": [[[[1, 2], [3]], [[4, 5, 6], [7], [8, 9]]]], - - }, - { - # Structure:{"r":[{"data":[1, 2]}, {"data":[3]}]}, - # {"r":[{"data":[4, 5, 6]}, {"data":[7]}, {"data":[8, 9]}]}, - "testcase_name": "test_field_with_two_dimensions_axis_1", - "shape": [2], - "fields": { - "r": tf.ragged.constant( - [[[1, 2], [3]], [[4, 5, 6], [7], [8, 9]]]), - - }, - "axis": 1, - "target_field": "r", - "expected": [ - [[[1, 2], [3]]], - [[[4, 5, 6], [7], [8, 9]]] - ] - }, - { - # Structure:{"foo":[{"bar":[1, 2]}, {"bar":[3]}]}, - # {"foo":[{"bar":[4, 5, 6]}, {"bar":[7]}, {"bar":[8, 9]}]}, - "testcase_name": "StructTensor_within_StructTensor", - "shape": [2], - "fields": { - "foo": - _make_structured_tensor([2, None], - {"bar": - tf.ragged.constant( - [[[1, 2], [3]], - [[4, 5, 6], [7], [8, 9]]])}), - }, - "axis": 0, - }, - { - "testcase_name": "multiple_normal_fields", - "shape": [2], - "fields": { - "a": tf.ragged.constant([[1, 2], [3]]), - "b": tf.ragged.constant([[4, 5], [6]]), - }, - "axis": 1, - "target_field": "a", - "expected": [[[1, 2]], [[3]]], - }, - { - "testcase_name": "multiple_fields_with_two_dimensions", - "shape": [2], - "fields": { - "a": tf.ragged.constant( - [[[1, 2], [3]], [[4, 5, 6], [7], [8, 9]]]), - "b": tf.ragged.constant( - [[[1], []], [[2, 3], [4, 5, 6], [7, 8]]]), - }, - "axis": 1, - "target_field": "a", - "expected": [[[[1, 2], [3]]], [[[4, 5, 6], [7], [8, 9]]]], - }, - { - "testcase_name": "one_empty_field", - "shape": [0], - "fields": { - "t": [], - }, - "axis": 0, - }, - { - "testcase_name": "shallow_field", - "shape": [2], - "fields": {"t": [5, 8]}, - "axis": 0, - "target_field": "t", - "expected": [[5, 8]], - }, - { - "testcase_name": "shallow_message", - "shape": [2], - "fields": {"t": 
structured_tensor.StructuredTensor.from_pyval([{"a": [7]}, {"a": [8]}])}, - "axis": 0, - }, - { - "testcase_name": "zero_list_tensor", - "shape": [], - "fields": { - }, - "axis": 0, - }, - { - # A scalar value for a scalar structured tensor. - "testcase_name": "one_scalar", - "shape": [], - "fields": { - "t": 3 - }, - "axis": 0, - "target_field": "t", - "expected": [3], - }, - { - # A repeated value for a scalar structured tensor. - "testcase_name": "scalar_repeated", - "shape": [], - "fields": { - "t": [3, 4, 5] - }, - "axis": 0, - "target_field": "t", - "expected": [[3, 4, 5]], - }, - ]) # pyformat: disable - def testField(self, shape, fields, axis=0, target_field=None, expected=None): - struct = _make_structured_tensor(shape, fields) - actual = structured_tensor_to_prensor._expand_dims(struct, axis=axis) - if expected is not None: - self.assertAllEqual(actual.field_value(target_field), expected) - - @parameterized.named_parameters([ - { - # Structure:{},{} - "testcase_name": "large_axis_small_rank", - "shape": [2], - "fields": { - }, - "axis": 2, - "reg_error": "Axis larger than rank: 2 > 1", - }, - { - # Structure:[] - "testcase_name": "big_negative", - "shape": [0], - "fields": { - "t": np.zeros([0, 5]), - }, - "axis": -3, - "reg_error": "Axis out of range: -3", - }, - { - # Currently, axis >= 2 is disallowed. - "testcase_name": "temporary_test_limiting_performance", - "shape": [3, 3, 3, 3], - "fields": { - }, - "axis": 2, - "reg_error": "Unimplemented: .* > 1 for _expand_dims", - }, - ]) # pyformat: disable - def testRaises(self, shape, fields, axis=0, reg_error=None): - struct = _make_structured_tensor(shape, fields) - with self.assertRaisesRegex(ValueError, reg_error): - structured_tensor_to_prensor._expand_dims(struct, axis=axis) + @parameterized.named_parameters( + [ + { + # Structure:{},{} + "testcase_name": "no_fields", + "shape": [2], + "fields": {}, + "axis": 0, + }, + { + # Structure:[] + "testcase_name": "field_of_zeroes", + "shape": [0], + "fields": { + "t": np.zeros([0, 5]), + }, + "axis": -1, + }, + { + # Structure:[][unknown ragged] + "testcase_name": "empty_2d", + "shape": [2, 2], + "fields": {}, + "axis": 1, + }, + { + # Structure:{"r":[1,2]},{"r":[3]} + "testcase_name": "one_normal_field", + "shape": [2], + "fields": { + "r": tf.ragged.constant([[1, 2], [3]]), + }, + "axis": 1, + "target_field": "r", + "expected": [[[1, 2]], [[3]]], + }, + { + # Structure:{"r":[1, 2]},{"r":[3, 4]} + "testcase_name": "one_dense_field_axis_0", + "shape": [2], + "fields": { + "r": tf.constant([[1, 2], [3, 4]]), + }, + "axis": 0, + "target_field": "r", + "expected": [[[1, 2], [3, 4]]], + }, + { + # Structure:{"r":[1, 2]},{"r":[3, 4]} + "testcase_name": "one_dense_field_axis_1", + "shape": [2], + "fields": { + "r": tf.constant([[1, 2], [3, 4]]), + }, + "axis": 1, + "target_field": "r", + "expected": [[[1, 2]], [[3, 4]]], + }, + { + # Structure:{"zzz":[{"data":[1, 2]}, {"data":[3]}]}, + # {"zzz":[{"data":[4, 5, 6]}, {"data":[7]}, + # {"data":[8, 9]}]}, + "testcase_name": "field_with_two_dimensions_axis_0", + "shape": [2], + "fields": { + "zzz": tf.ragged.constant( + [[[1, 2], [3]], [[4, 5, 6], [7], [8, 9]]] + ), + }, + "axis": 0, + "target_field": "zzz", + "expected": [[[[1, 2], [3]], [[4, 5, 6], [7], [8, 9]]]], + }, + { + # Structure:{"r":[{"data":[1, 2]}, {"data":[3]}]}, + # {"r":[{"data":[4, 5, 6]}, {"data":[7]}, {"data":[8, 9]}]}, + "testcase_name": "test_field_with_two_dimensions_axis_1", + "shape": [2], + "fields": { + "r": tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6], [7], [8, 
9]]]), + }, + "axis": 1, + "target_field": "r", + "expected": [[[[1, 2], [3]]], [[[4, 5, 6], [7], [8, 9]]]], + }, + { + # Structure:{"foo":[{"bar":[1, 2]}, {"bar":[3]}]}, + # {"foo":[{"bar":[4, 5, 6]}, {"bar":[7]}, {"bar":[8, 9]}]}, + "testcase_name": "StructTensor_within_StructTensor", + "shape": [2], + "fields": { + "foo": _make_structured_tensor( + [2, None], + { + "bar": tf.ragged.constant( + [[[1, 2], [3]], [[4, 5, 6], [7], [8, 9]]] + ) + }, + ), + }, + "axis": 0, + }, + { + "testcase_name": "multiple_normal_fields", + "shape": [2], + "fields": { + "a": tf.ragged.constant([[1, 2], [3]]), + "b": tf.ragged.constant([[4, 5], [6]]), + }, + "axis": 1, + "target_field": "a", + "expected": [[[1, 2]], [[3]]], + }, + { + "testcase_name": "multiple_fields_with_two_dimensions", + "shape": [2], + "fields": { + "a": tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6], [7], [8, 9]]]), + "b": tf.ragged.constant([[[1], []], [[2, 3], [4, 5, 6], [7, 8]]]), + }, + "axis": 1, + "target_field": "a", + "expected": [[[[1, 2], [3]]], [[[4, 5, 6], [7], [8, 9]]]], + }, + { + "testcase_name": "one_empty_field", + "shape": [0], + "fields": { + "t": [], + }, + "axis": 0, + }, + { + "testcase_name": "shallow_field", + "shape": [2], + "fields": {"t": [5, 8]}, + "axis": 0, + "target_field": "t", + "expected": [[5, 8]], + }, + { + "testcase_name": "shallow_message", + "shape": [2], + "fields": { + "t": structured_tensor.StructuredTensor.from_pyval( + [{"a": [7]}, {"a": [8]}] + ) + }, + "axis": 0, + }, + { + "testcase_name": "zero_list_tensor", + "shape": [], + "fields": {}, + "axis": 0, + }, + { + # A scalar value for a scalar structured tensor. + "testcase_name": "one_scalar", + "shape": [], + "fields": {"t": 3}, + "axis": 0, + "target_field": "t", + "expected": [3], + }, + { + # A repeated value for a scalar structured tensor. + "testcase_name": "scalar_repeated", + "shape": [], + "fields": {"t": [3, 4, 5]}, + "axis": 0, + "target_field": "t", + "expected": [[3, 4, 5]], + }, + ] + ) # pyformat: disable + def testField(self, shape, fields, axis=0, target_field=None, expected=None): + struct = _make_structured_tensor(shape, fields) + actual = structured_tensor_to_prensor._expand_dims(struct, axis=axis) + if expected is not None: + self.assertAllEqual(actual.field_value(target_field), expected) + + @parameterized.named_parameters( + [ + { + # Structure:{},{} + "testcase_name": "large_axis_small_rank", + "shape": [2], + "fields": {}, + "axis": 2, + "reg_error": "Axis larger than rank: 2 > 1", + }, + { + # Structure:[] + "testcase_name": "big_negative", + "shape": [0], + "fields": { + "t": np.zeros([0, 5]), + }, + "axis": -3, + "reg_error": "Axis out of range: -3", + }, + { + # Currently, axis >= 2 is disallowed. 
+ "testcase_name": "temporary_test_limiting_performance", + "shape": [3, 3, 3, 3], + "fields": {}, + "axis": 2, + "reg_error": "Unimplemented: .* > 1 for _expand_dims", + }, + ] + ) # pyformat: disable + def testRaises(self, shape, fields, axis=0, reg_error=None): + struct = _make_structured_tensor(shape, fields) + with self.assertRaisesRegex(ValueError, reg_error): + structured_tensor_to_prensor._expand_dims(struct, axis=axis) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/struct2tensor/test/expression_test_util.py b/struct2tensor/test/expression_test_util.py index d8a0edc..5ebabe3 100644 --- a/struct2tensor/test/expression_test_util.py +++ b/struct2tensor/test/expression_test_util.py @@ -27,146 +27,154 @@ def calculate_value_slowly( expr: expression.Expression, destinations: Optional[Sequence[expression.Expression]] = None, - options: Optional[calculate_options.Options] = None) -> prensor.NodeTensor: - """A calculation of the node tensor of an expression, without optimization. - - This will not do any common subexpression elimination or caching of - node tensors, and will likely be very slow for larger operations. - - Args: - expr: The expression to calculate. - destinations: Where the calculation will be used (None implies directly) - options: Calculation options for individual calculations - - Returns: - The node tensor of the expression. - """ - new_options = calculate_options.get_default_options( - ) if options is None else options - - source_node_tensors = [ - calculate_value_slowly(x, [expr], new_options) - for x in expr.get_source_expressions() - ] - real_dest = [] if destinations is None else destinations - return expr.calculate(source_node_tensors, real_dest, new_options) - - -def calculate_list_map(expr: expression.Expression, - evaluator, - options: Optional[calculate_options.Options] = None): - """Calculate a map from paths to nested lists, representing the leafs.""" - [my_prensor] = calculate.calculate_prensors([expr], options=options) - if not options: - options = calculate_options.get_default_options() - ragged_tensor_map = my_prensor.get_ragged_tensors(options) - string_tensor_map = {str(k): v for k, v in ragged_tensor_map.items()} - string_np_map = evaluator.evaluate(string_tensor_map) - return {k: v.to_list() for k, v in string_np_map.items()} + options: Optional[calculate_options.Options] = None, +) -> prensor.NodeTensor: + """A calculation of the node tensor of an expression, without optimization. - -class MockExpression(expression.Expression): - """The mock expression is designed to test calculations.""" - - def __init__(self, - is_repeated: bool, - my_type: Optional[tf.DType], - name: Optional[str] = None, - source_expressions: Optional[Sequence["MockExpression"]] = None, - calculate_output: Optional[prensor.NodeTensor] = None, - calculate_is_identity: bool = False, - children: Optional[Mapping[path.Step, - expression.Expression]] = None, - known_field_names: Optional[FrozenSet[path.Step]] = None, - schema_feature: Optional[schema_pb2.Feature] = None): - """Initialize an expression. + This will not do any common subexpression elimination or caching of + node tensors, and will likely be very slow for larger operations. Args: - is_repeated: whether the output is_repeated. - my_type: what my type is. - name: the name of this expression. - source_expressions: the source expressions. - calculate_output: the output returned. - calculate_is_identity: if this returns the identity. 
-      children: the children of this expression
-      known_field_names: the known children of this expression.
-      schema_feature: schema information about the feature.
+      expr: The expression to calculate.
+      destinations: Where the calculation will be used (None implies directly).
+      options: Calculation options for individual calculations.
+
+    Returns:
+      The node tensor of the expression.
     """
-  super().__init__(is_repeated, my_type, schema_feature=schema_feature)
-  self._name = "Unknown" if name is None else name
-  self._source_expressions = []
-  if source_expressions is not None:
-    self._source_expressions = source_expressions
-  self._expected_source_tensors = [
-      x.calculate_output for x in self._source_expressions
+    new_options = (
+        calculate_options.get_default_options() if options is None else options
+    )
+
+    source_node_tensors = [
+        calculate_value_slowly(x, [expr], new_options)
+        for x in expr.get_source_expressions()
     ]
-  self._calculate_output = calculate_output
-  self._calculate_is_identity = calculate_is_identity
-  self._children = {} if children is None else children
-  if known_field_names is None:
-    self._known_field_names = frozenset(self._children.keys())
+    real_dest = [] if destinations is None else destinations
+    return expr.calculate(source_node_tensors, real_dest, new_options)
+
+
+def calculate_list_map(
+    expr: expression.Expression,
+    evaluator,
+    options: Optional[calculate_options.Options] = None,
+):
+    """Calculate a map from paths to nested lists, representing the leaves."""
+    [my_prensor] = calculate.calculate_prensors([expr], options=options)
+    if not options:
+        options = calculate_options.get_default_options()
+    ragged_tensor_map = my_prensor.get_ragged_tensors(options)
+    string_tensor_map = {str(k): v for k, v in ragged_tensor_map.items()}
+    string_np_map = evaluator.evaluate(string_tensor_map)
+    return {k: v.to_list() for k, v in string_np_map.items()}
+
+
+class MockExpression(expression.Expression):
+    """The mock expression is designed to test calculations."""
+
+    def __init__(
+        self,
+        is_repeated: bool,
+        my_type: Optional[tf.DType],
+        name: Optional[str] = None,
+        source_expressions: Optional[Sequence["MockExpression"]] = None,
+        calculate_output: Optional[prensor.NodeTensor] = None,
+        calculate_is_identity: bool = False,
+        children: Optional[Mapping[path.Step, expression.Expression]] = None,
+        known_field_names: Optional[FrozenSet[path.Step]] = None,
+        schema_feature: Optional[schema_pb2.Feature] = None,
+    ):
+        """Initialize an expression.
+
+        Args:
+          is_repeated: whether the output is_repeated.
+          my_type: what my type is.
+          name: the name of this expression.
+          source_expressions: the source expressions.
+          calculate_output: the output returned.
+          calculate_is_identity: if this returns the identity.
+          children: the children of this expression.
+          known_field_names: the known children of this expression.
+          schema_feature: schema information about the feature.
+ """ + super().__init__(is_repeated, my_type, schema_feature=schema_feature) + self._name = "Unknown" if name is None else name + self._source_expressions = [] + if source_expressions is not None: + self._source_expressions = source_expressions + self._expected_source_tensors = [ + x.calculate_output for x in self._source_expressions + ] + self._calculate_output = calculate_output + self._calculate_is_identity = calculate_is_identity + self._children = {} if children is None else children + if known_field_names is None: + self._known_field_names = frozenset(self._children.keys()) + else: + self._known_field_names = known_field_names + + @property + def calculate_output(self): + """The output returned by this expression.""" + if self._calculate_output is None: + raise ValueError(f"Did not specify calculate_output for {self._name}") + return self._calculate_output + + def get_source_expressions(self) -> Sequence[expression.Expression]: + return self._source_expressions + + def calculate( + self, + source_tensors: Sequence[prensor.NodeTensor], + destinations: Sequence[expression.Expression], + options: calculate_options.Options, + side_info: Optional[prensor.Prensor] = None, + ) -> prensor.NodeTensor: + if len(source_tensors) != len(self._expected_source_tensors): + raise ValueError(f"Unexpected number of inputs for {self._name}.") + for i in range(len(source_tensors)): + if self._expected_source_tensors[i] is not source_tensors[i]: + raise ValueError("Error calculating " + self._name) + return self._calculate_output + + def calculation_is_identity(self) -> bool: + return self._calculate_is_identity + + def calculation_equal(self, expr: expression.Expression) -> bool: + return self.calculation_is_identity() and expr.calculation_is_identity() + + def _get_child_impl(self, field_name: path.Step) -> Optional[expression.Expression]: + return self._children.get(field_name) + + def known_field_names(self) -> FrozenSet[path.Step]: + return self._known_field_names + + def __str__(self) -> str: + return str(self._name) + + +def get_mock_leaf( + is_repeated: bool, + my_type: tf.DType, + name: Optional[str] = None, + source_expressions: Optional[Sequence[MockExpression]] = None, + calculate_is_identity: bool = False, +): + """Gets a leaf expression.""" + if calculate_is_identity: + calculate_output = source_expressions[0].calculate_output else: - self._known_field_names = known_field_names - - @property - def calculate_output(self): - """The output returned by this expression.""" - if self._calculate_output is None: - raise ValueError(f"Did not specify calculate_output for {self._name}") - return self._calculate_output - - def get_source_expressions(self) -> Sequence[expression.Expression]: - return self._source_expressions - - def calculate( - self, - source_tensors: Sequence[prensor.NodeTensor], - destinations: Sequence[expression.Expression], - options: calculate_options.Options, - side_info: Optional[prensor.Prensor] = None) -> prensor.NodeTensor: - if len(source_tensors) != len(self._expected_source_tensors): - raise ValueError(f"Unexpected number of inputs for {self._name}.") - for i in range(len(source_tensors)): - if self._expected_source_tensors[i] is not source_tensors[i]: - raise ValueError("Error calculating " + self._name) - return self._calculate_output - - def calculation_is_identity(self) -> bool: - return self._calculate_is_identity - - def calculation_equal(self, expr: expression.Expression) -> bool: - return self.calculation_is_identity() and expr.calculation_is_identity() - - def 
_get_child_impl(self,
-                      field_name: path.Step) -> Optional[expression.Expression]:
-    return self._children.get(field_name)
-
-  def known_field_names(self) -> FrozenSet[path.Step]:
-    return self._known_field_names
-
-  def __str__(self) -> str:
-    return str(self._name)
-
-
-def get_mock_leaf(is_repeated: bool,
-                  my_type: tf.DType,
-                  name: Optional[str] = None,
-                  source_expressions: Optional[Sequence[MockExpression]] = None,
-                  calculate_is_identity: bool = False):
-  """Gets a leaf expression."""
-  if calculate_is_identity:
-    calculate_output = source_expressions[0].calculate_output
-  else:
-    calculate_output = prensor.LeafNodeTensor(
-        tf.constant([], dtype=tf.int64), tf.constant([], dtype=my_type),
-        is_repeated)
-  return MockExpression(
-      is_repeated,
-      my_type,
-      name=name,
-      source_expressions=source_expressions,
-      calculate_output=calculate_output,
-      calculate_is_identity=calculate_is_identity)
+    calculate_output = prensor.LeafNodeTensor(
+        tf.constant([], dtype=tf.int64), tf.constant([], dtype=my_type), is_repeated
+    )
+    return MockExpression(
+        is_repeated,
+        my_type,
+        name=name,
+        source_expressions=source_expressions,
+        calculate_output=calculate_output,
+        calculate_is_identity=calculate_is_identity,
+    )


 def get_mock_broken_leaf(
@@ -176,36 +184,40 @@ def get_mock_broken_leaf(
     actual_type: tf.DType,
     name: Optional[str] = None,
     source_expressions: Optional[Sequence[MockExpression]] = None,
-    calculate_is_identity: bool = False):
-  """Gets a leaf expression flexible enough not to typecheck.
-
-  If declared_is_repeated != actual_is_repeated,
-  or declared_type != actual_type, then this will not typecheck
-  when _ExpressionNode.calculate() is called.
-
-  Args:
-    declared_is_repeated: the is_repeated of the expression.
-    declared_type: the type of the expression.
-    actual_is_repeated: the is_repeated of the NodeTensor.
-    actual_type: the type of the NodeTensor.
-    name: a name of the expression.
-    source_expressions: the result of get_source expressions()
-    calculate_is_identity: true iff this should say it is the identity.
-
-  Returns:
-    An expression.
-
-  """
-  if calculate_is_identity:
-    calculate_output = source_expressions[0].calculate_output
-  else:
-    calculate_output = prensor.LeafNodeTensor(
-        tf.constant([], dtype=tf.int64), tf.constant([], dtype=actual_type),
-        actual_is_repeated)
-  return MockExpression(
-      declared_is_repeated,
-      declared_type,
-      name=name,
-      source_expressions=source_expressions,
-      calculate_output=calculate_output,
-      calculate_is_identity=calculate_is_identity)
+    calculate_is_identity: bool = False,
+):
+    """Gets a leaf expression flexible enough not to typecheck.
+
+    If declared_is_repeated != actual_is_repeated,
+    or declared_type != actual_type, then this will not typecheck
+    when _ExpressionNode.calculate() is called.
+
+    Args:
+      declared_is_repeated: the is_repeated of the expression.
+      declared_type: the type of the expression.
+      actual_is_repeated: the is_repeated of the NodeTensor.
+      actual_type: the type of the NodeTensor.
+      name: a name of the expression.
+      source_expressions: the result of get_source_expressions().
+      calculate_is_identity: true iff this should say it is the identity.
+
+    Returns:
+      An expression.
+
+    """
+    if calculate_is_identity:
+        calculate_output = source_expressions[0].calculate_output
+    else:
+        calculate_output = prensor.LeafNodeTensor(
+            tf.constant([], dtype=tf.int64),
+            tf.constant([], dtype=actual_type),
+            actual_is_repeated,
+        )
+    return MockExpression(
+        declared_is_repeated,
+        declared_type,
+        name=name,
+        source_expressions=source_expressions,
+        calculate_output=calculate_output,
+        calculate_is_identity=calculate_is_identity,
+    )
diff --git a/struct2tensor/test/prensor_test_util.py b/struct2tensor/test/prensor_test_util.py
index 0ee7724..b61c6b4 100644
--- a/struct2tensor/test/prensor_test_util.py
+++ b/struct2tensor/test/prensor_test_util.py
@@ -27,143 +27,153 @@


 def create_root_node(size: int) -> prensor.RootNodeTensor:
-  return prensor.RootNodeTensor(tf.constant(size, dtype=tf.int64))
+    return prensor.RootNodeTensor(tf.constant(size, dtype=tf.int64))


-def create_child_node(parent_index: Sequence[int],
-                      is_repeated: bool) -> prensor.ChildNodeTensor:
-  return prensor.ChildNodeTensor(
-      tf.constant(parent_index, dtype=tf.int64), is_repeated)
+def create_child_node(
+    parent_index: Sequence[int], is_repeated: bool
+) -> prensor.ChildNodeTensor:
+    return prensor.ChildNodeTensor(
+        tf.constant(parent_index, dtype=tf.int64), is_repeated
+    )


-def create_optional_leaf_node(parent_index: Sequence[int],
-                              values: Sequence[Any]) -> prensor.LeafNodeTensor:
-  """Creates an optional leaf node.
+def create_optional_leaf_node(
+    parent_index: Sequence[int], values: Sequence[Any]
+) -> prensor.LeafNodeTensor:
+    """Creates an optional leaf node.

-  Args:
-    parent_index: a list of integers that is converted to a 1-D int64 tensor.
-    values: a list of whatever type that the field represents.
+    Args:
+      parent_index: a list of integers that is converted to a 1-D int64 tensor.
+      values: a list of values of whatever type the field represents.

-  Returns:
-    A PrensorField with the parent_index and values set appropriately.
-  """
-  return prensor.LeafNodeTensor(
-      tf.constant(parent_index, dtype=tf.int64), tf.constant(values), False)
+    Returns:
+      A LeafNodeTensor with the parent_index and values set appropriately.
+    """
+    return prensor.LeafNodeTensor(
+        tf.constant(parent_index, dtype=tf.int64), tf.constant(values), False
+    )


-def create_repeated_leaf_node(parent_index: Sequence[int],
-                              values: Sequence[Any]):
-  """Creates a repeated PrensorField.
+def create_repeated_leaf_node(parent_index: Sequence[int], values: Sequence[Any]):
+    """Creates a repeated LeafNodeTensor.

-  Args:
-    parent_index: a list of integers that is converted to a 1-D int64 tensor.
-    values: a list of whatever type that the field represents.
+    Args:
+      parent_index: a list of integers that is converted to a 1-D int64 tensor.
+      values: a list of values of whatever type the field represents.

-  Returns:
-    A PrensorField with the parent_index and values set appropriately.
-  """
-  return prensor.LeafNodeTensor(
-      tf.constant(parent_index, dtype=tf.int64), tf.constant(values), True)
+    Returns:
+      A LeafNodeTensor with the parent_index and values set appropriately.
+    """
+    return prensor.LeafNodeTensor(
+        tf.constant(parent_index, dtype=tf.int64), tf.constant(values), True
+    )


 def create_simple_prensor() -> prensor.Prensor:
-  """Creates a prensor expression representing a list of flat protocol buffers.
- - Returns: - a RootPrensor representing: - {foo:9, foorepeated:[9]} - {foo:8, foorepeated:[8,7]} - {foo:7, foorepeated:[6]} - """ - return prensor.create_prensor_from_descendant_nodes({ - path.Path([]): - create_root_node(3), - path.Path(["foo"]): - create_optional_leaf_node([0, 1, 2], [9, 8, 7]), - path.Path(["foorepeated"]): - create_repeated_leaf_node([0, 1, 1, 2], [9, 8, 7, 6]) - }) + """Creates a prensor expression representing a list of flat protocol buffers. + + Returns: + a RootPrensor representing: + {foo:9, foorepeated:[9]} + {foo:8, foorepeated:[8,7]} + {foo:7, foorepeated:[6]} + """ + return prensor.create_prensor_from_descendant_nodes( + { + path.Path([]): create_root_node(3), + path.Path(["foo"]): create_optional_leaf_node([0, 1, 2], [9, 8, 7]), + path.Path(["foorepeated"]): create_repeated_leaf_node( + [0, 1, 1, 2], [9, 8, 7, 6] + ), + } + ) def create_broken_prensor() -> prensor.Prensor: - """Creates a broken prensor where the parent indices are out of order. - - Returns: - a RootPrensor where the parent indices are out of order. - """ - return prensor.create_prensor_from_descendant_nodes({ - path.Path([]): - create_root_node(3), - path.Path(["foo"]): - create_optional_leaf_node([1, 0, 2], [9, 8, 7]), - path.Path(["foorepeated"]): - create_repeated_leaf_node([1, 0, 1, 2], [9, 8, 7, 6]) - }) + """Creates a broken prensor where the parent indices are out of order. + + Returns: + a RootPrensor where the parent indices are out of order. + """ + return prensor.create_prensor_from_descendant_nodes( + { + path.Path([]): create_root_node(3), + path.Path(["foo"]): create_optional_leaf_node([1, 0, 2], [9, 8, 7]), + path.Path(["foorepeated"]): create_repeated_leaf_node( + [1, 0, 1, 2], [9, 8, 7, 6] + ), + } + ) def create_nested_prensor() -> prensor.Prensor: - """Creates a prensor representing a list of nested protocol buffers. - - Returns: - a prensor expression representing: - {doc:[{bar:["a"], keep_me:False}], user:[{friends:["a"]}]} - {doc:[{bar:["b","c"], keep_me:True}, {bar:["d"]}], - user:[{friends:["b", "c"]}, {friends:["d"]}]} - {user:[friends:["e"]]} - """ - return prensor.create_prensor_from_descendant_nodes({ - path.Path([]): - create_root_node(3), - path.Path(["doc"]): - create_child_node([0, 1, 1], True), - path.Path(["doc", "bar"]): - create_repeated_leaf_node([0, 1, 1, 2], ["a", "b", "c", "d"]), - path.Path(["doc", "keep_me"]): - create_optional_leaf_node([0, 1], [False, True]), - path.Path(["user"]): - create_child_node([0, 1, 1, 2], True), - path.Path(["user", "friends"]): - create_repeated_leaf_node([0, 1, 1, 2, 3], ["a", "b", "c", "d", "e"]) - }) + """Creates a prensor representing a list of nested protocol buffers. 
+ + Returns: + a prensor expression representing: + {doc:[{bar:["a"], keep_me:False}], user:[{friends:["a"]}]} + {doc:[{bar:["b","c"], keep_me:True}, {bar:["d"]}], + user:[{friends:["b", "c"]}, {friends:["d"]}]} + {user:[friends:["e"]]} + """ + return prensor.create_prensor_from_descendant_nodes( + { + path.Path([]): create_root_node(3), + path.Path(["doc"]): create_child_node([0, 1, 1], True), + path.Path(["doc", "bar"]): create_repeated_leaf_node( + [0, 1, 1, 2], ["a", "b", "c", "d"] + ), + path.Path(["doc", "keep_me"]): create_optional_leaf_node( + [0, 1], [False, True] + ), + path.Path(["user"]): create_child_node([0, 1, 1, 2], True), + path.Path(["user", "friends"]): create_repeated_leaf_node( + [0, 1, 1, 2, 3], ["a", "b", "c", "d", "e"] + ), + } + ) def create_nested_prensor_with_lenient_field_names() -> prensor.Prensor: - """Creates a prensor representing a data structure with proto-invalid names. - - Returns: - a prensor expression representing: - {doc:[{bar.baz:[1], keep_me/x:False}], user:[{friends!:):[1]}]} - {doc:[{bar:[2, 3], keep_me/x:True}, {bar:[4]}], - user:[{friends!:):[2, 3]}, {friends!:):[4]}]} - {user:[friends!:):[5]]} - """ - return prensor.create_prensor_from_descendant_nodes({ - path.Path([], validate_step_format=False): create_root_node(3), - path.Path(["doc"], validate_step_format=False): create_child_node( - [0, 1, 1], True - ), - path.Path( - ["doc", "bar.baz"], validate_step_format=False - ): create_repeated_leaf_node([0, 1, 1, 2], [1, 2, 3, 4]), - path.Path( - ["doc", "keep_me/x"], validate_step_format=False - ): create_optional_leaf_node( - [0, 1], - [False, True], - ), - path.Path(["user"], validate_step_format=False): create_child_node( - [0, 1, 1, 2], True - ), - path.Path( - ["user", "friends!:)"], validate_step_format=False - ): create_repeated_leaf_node([0, 1, 1, 2, 3], [1, 2, 3, 4, 5]), - }) + """Creates a prensor representing a data structure with proto-invalid names. + + Returns: + a prensor expression representing: + {doc:[{bar.baz:[1], keep_me/x:False}], user:[{friends!:):[1]}]} + {doc:[{bar:[2, 3], keep_me/x:True}, {bar:[4]}], + user:[{friends!:):[2, 3]}, {friends!:):[4]}]} + {user:[friends!:):[5]]} + """ + return prensor.create_prensor_from_descendant_nodes( + { + path.Path([], validate_step_format=False): create_root_node(3), + path.Path(["doc"], validate_step_format=False): create_child_node( + [0, 1, 1], True + ), + path.Path( + ["doc", "bar.baz"], validate_step_format=False + ): create_repeated_leaf_node([0, 1, 1, 2], [1, 2, 3, 4]), + path.Path( + ["doc", "keep_me/x"], validate_step_format=False + ): create_optional_leaf_node( + [0, 1], + [False, True], + ), + path.Path(["user"], validate_step_format=False): create_child_node( + [0, 1, 1, 2], True + ), + path.Path( + ["user", "friends!:)"], validate_step_format=False + ): create_repeated_leaf_node([0, 1, 1, 2, 3], [1, 2, 3, 4, 5]), + } + ) def create_big_prensor_schema(): - """Create a schema that is aligned with create_big_prensor.""" - return text_format.Parse( - """ + """Create a schema that is aligned with create_big_prensor.""" + return text_format.Parse( + """ int_domain: { name: 'zero_to_ten' min: 0 @@ -207,56 +217,60 @@ def create_big_prensor_schema(): } } } - """, schema_pb2.Schema()) + """, + schema_pb2.Schema(), + ) def create_big_prensor(): - """Creates a prensor representing a list of nested protocol buffers. 
- - Returns: - a prensor expression representing: - { - foo: 9, - foorepeated: [9], - doc: [{bar:["a"], keep_me:False}], - user: [{friends:["a"]}] - }, - { - foo: 8, - foorepeated: [8, 7], - doc: [{bar:["b","c"],keep_me:True},{bar:["d"]}], - user: [{friends:["b", "c"]},{friends:["d"]}], - }, - { - foo: 7, - foorepeated: [6], - user: [{friends:["e"]}] - } - """ - return prensor.create_prensor_from_descendant_nodes({ - path.Path([]): - create_root_node(3), - path.Path(["foo"]): - create_optional_leaf_node([0, 1, 2], [9, 8, 7]), - path.Path(["doc"]): - create_child_node([0, 1, 1], True), - path.Path(["doc", "keep_me"]): - create_optional_leaf_node([0, 1], [False, True]), - path.Path(["doc", "bar"]): - create_repeated_leaf_node([0, 1, 1, 2], ["a", "b", "c", "d"]), - path.Path(["user"]): - create_child_node([0, 1, 1, 2], True), - path.Path(["user", "friends"]): - create_repeated_leaf_node([0, 1, 1, 2, 3], ["a", "b", "c", "d", "e"]), - path.Path(["foorepeated"]): - create_repeated_leaf_node([0, 1, 1, 2], [9, 8, 7, 6]), - }) + """Creates a prensor representing a list of nested protocol buffers. + + Returns: + a prensor expression representing: + { + foo: 9, + foorepeated: [9], + doc: [{bar:["a"], keep_me:False}], + user: [{friends:["a"]}] + }, + { + foo: 8, + foorepeated: [8, 7], + doc: [{bar:["b","c"],keep_me:True},{bar:["d"]}], + user: [{friends:["b", "c"]},{friends:["d"]}], + }, + { + foo: 7, + foorepeated: [6], + user: [{friends:["e"]}] + } + """ + return prensor.create_prensor_from_descendant_nodes( + { + path.Path([]): create_root_node(3), + path.Path(["foo"]): create_optional_leaf_node([0, 1, 2], [9, 8, 7]), + path.Path(["doc"]): create_child_node([0, 1, 1], True), + path.Path(["doc", "keep_me"]): create_optional_leaf_node( + [0, 1], [False, True] + ), + path.Path(["doc", "bar"]): create_repeated_leaf_node( + [0, 1, 1, 2], ["a", "b", "c", "d"] + ), + path.Path(["user"]): create_child_node([0, 1, 1, 2], True), + path.Path(["user", "friends"]): create_repeated_leaf_node( + [0, 1, 1, 2, 3], ["a", "b", "c", "d", "e"] + ), + path.Path(["foorepeated"]): create_repeated_leaf_node( + [0, 1, 1, 2], [9, 8, 7, 6] + ), + } + ) def create_deep_prensor_schema(): - """Create a schema that is aligned with create_deep_prensor.""" - return text_format.Parse( - """ + """Create a schema that is aligned with create_deep_prensor.""" + return text_format.Parse( + """ int_domain: { name: 'zero_to_ten' min: 0 @@ -305,99 +319,104 @@ def create_deep_prensor_schema(): } } } - """, schema_pb2.Schema()) + """, + schema_pb2.Schema(), + ) def create_deep_prensor(): - """Creates prensor with three layers: root, event, and doc. 
- - Returns: - a prensor expression representing: - { - foo: 9, - foorepeated: [9], - user: [{friends:["a"]}], - event: [{doc:[{bar:["a"], keep_me:False}]}] - }, - { - foo: 8, - foorepeated: [8,7], - user: [{friends:["b", "c"]}, {friends:["d"]}], - event: [{doc:[{bar:["b","c"], keep_me:True},{bar:["d"]}]}] - }, - { - foo:7, - foorepeated: [6], - user: [{friends:["e"]}], - event: [{}] - } - """ - return prensor.create_prensor_from_descendant_nodes({ - path.Path([]): - create_root_node(3), - path.Path(["event"]): - create_child_node([0, 1, 2], True), - path.Path(["event", "doc"]): - create_child_node([0, 1, 1], True), - path.Path(["event", "doc", "bar"]): - create_repeated_leaf_node([0, 1, 1, 2], ["a", "b", "c", "d"]), - path.Path(["event", "doc", "keep_me"]): - create_optional_leaf_node([0, 1], [False, True]), - path.Path(["foo"]): - create_optional_leaf_node([0, 1, 2], [9, 8, 7]), - path.Path(["foorepeated"]): - create_repeated_leaf_node([0, 1, 1, 2], [9, 8, 7, 6]), - path.Path(["user"]): - create_child_node([0, 1, 1, 2], True), - path.Path(["user", "friends"]): - create_repeated_leaf_node([0, 1, 1, 2, 3], ["a", "b", "c", "d", "e"]) - }) + """Creates prensor with three layers: root, event, and doc. + + Returns: + a prensor expression representing: + { + foo: 9, + foorepeated: [9], + user: [{friends:["a"]}], + event: [{doc:[{bar:["a"], keep_me:False}]}] + }, + { + foo: 8, + foorepeated: [8,7], + user: [{friends:["b", "c"]}, {friends:["d"]}], + event: [{doc:[{bar:["b","c"], keep_me:True},{bar:["d"]}]}] + }, + { + foo:7, + foorepeated: [6], + user: [{friends:["e"]}], + event: [{}] + } + """ + return prensor.create_prensor_from_descendant_nodes( + { + path.Path([]): create_root_node(3), + path.Path(["event"]): create_child_node([0, 1, 2], True), + path.Path(["event", "doc"]): create_child_node([0, 1, 1], True), + path.Path(["event", "doc", "bar"]): create_repeated_leaf_node( + [0, 1, 1, 2], ["a", "b", "c", "d"] + ), + path.Path(["event", "doc", "keep_me"]): create_optional_leaf_node( + [0, 1], [False, True] + ), + path.Path(["foo"]): create_optional_leaf_node([0, 1, 2], [9, 8, 7]), + path.Path(["foorepeated"]): create_repeated_leaf_node( + [0, 1, 1, 2], [9, 8, 7, 6] + ), + path.Path(["user"]): create_child_node([0, 1, 1, 2], True), + path.Path(["user", "friends"]): create_repeated_leaf_node( + [0, 1, 1, 2, 3], ["a", "b", "c", "d", "e"] + ), + } + ) def create_four_layer_prensor(): - """Creates prensor with four layers: root, event, doc, and nested_child. - - Returns: - a prensor expression representing: - { - event: { - doc: { - nested_child: { - bar:["a"], - keep_me:False + """Creates prensor with four layers: root, event, doc, and nested_child. 
+ + Returns: + a prensor expression representing: + { + event: { + doc: { + nested_child: { + bar:["a"], + keep_me:False + } } } } - } - { - event: { - doc: { - nested_child: { - bar:["b","c"], - keep_me:True + { + event: { + doc: { + nested_child: { + bar:["b","c"], + keep_me:True + }, + nested_child: { + bar:["d"] + } }, - nested_child: { - bar:["d"] + doc: { + nested_child: {} } - }, - doc: { - nested_child: {} } } - } - {event: {}} - """ - return prensor.create_prensor_from_descendant_nodes({ - path.Path([]): - create_root_node(3), - path.Path(["event"]): - create_child_node([0, 1, 2], True), - path.Path(["event", "doc"]): - create_child_node([0, 1, 1], True), - path.Path(["event", "doc", "nested_child"]): - create_child_node([0, 1, 1, 2], True), - path.Path(["event", "doc", "nested_child", "bar"]): - create_repeated_leaf_node([0, 1, 1, 2], ["a", "b", "c", "d"]), - path.Path(["event", "doc", "nested_child", "keep_me"]): - create_optional_leaf_node([0, 1], [False, True]) - }) + {event: {}} + """ + return prensor.create_prensor_from_descendant_nodes( + { + path.Path([]): create_root_node(3), + path.Path(["event"]): create_child_node([0, 1, 2], True), + path.Path(["event", "doc"]): create_child_node([0, 1, 1], True), + path.Path(["event", "doc", "nested_child"]): create_child_node( + [0, 1, 1, 2], True + ), + path.Path( + ["event", "doc", "nested_child", "bar"] + ): create_repeated_leaf_node([0, 1, 1, 2], ["a", "b", "c", "d"]), + path.Path( + ["event", "doc", "nested_child", "keep_me"] + ): create_optional_leaf_node([0, 1], [False, True]), + } + ) diff --git a/struct2tensor/tools/build_docs.py b/struct2tensor/tools/build_docs.py index 8ad3f36..e9643b9 100644 --- a/struct2tensor/tools/build_docs.py +++ b/struct2tensor/tools/build_docs.py @@ -35,117 +35,127 @@ flags.DEFINE_string( "code_url_prefix", "https://github.com/google/struct2tensor/blob/master/struct2tensor", - "The url prefix for links to code.") + "The url prefix for links to code.", +) -flags.DEFINE_bool("search_hints", True, - "Include metadata search hints in the generated files") +flags.DEFINE_bool( + "search_hints", True, "Include metadata search hints in the generated files" +) FLAGS = flags.FLAGS def build_docs(output_dir: pathlib.Path) -> None: - """Generates the api docs for struct2tensor and expression_impl. - - We need to splice the generated docs together, and create a new table of - contents. - - Args: - output_dir: A pathlib Path that is the top level dir where docs are - generated. - - Returns: - None - """ - s2t_out = pathlib.Path(tempfile.mkdtemp()) - doc_generator = generate_lib.DocGenerator( - root_title="Struct2Tensor", - py_modules=[("s2t", s2t)], - code_url_prefix=FLAGS.code_url_prefix, - search_hints=FLAGS.search_hints, - # explicit_package_contents_filter ensures that only modules imported - # directly from s2t/__init__.py are documented in the location that - # defines them, instead of every location that imports them. 
- callbacks=[ - public_api.explicit_package_contents_filter, _filter_module_attributes - ]) - - doc_generator.build(s2t_out) - - expr_impl_out = pathlib.Path(tempfile.mkdtemp()) - doc_generator = generate_lib.DocGenerator( - root_title="Struct2Tensor-expression_impl", - py_modules=[("expression_impl", expression_impl)], - code_url_prefix=FLAGS.code_url_prefix + "/expression_impl", - search_hints=FLAGS.search_hints, - # explicit_package_contents_filter ensures that only modules imported - # directly from s2t/expression_impl/__init__.py are documented in the - # location that defines them, instead of every location that imports them. - callbacks=[ - public_api.explicit_package_contents_filter, _filter_module_attributes - ]) - doc_generator.build(expr_impl_out) - - def splice(name, tmp_dir): - shutil.rmtree(output_dir / name, ignore_errors=True) - shutil.copytree(tmp_dir / name, output_dir / name) - shutil.copy(tmp_dir / f"{name}.md", output_dir / f"{name}.md") - with contextlib.suppress(FileNotFoundError): - shutil.copy(tmp_dir / name / "_redirects.yaml", - output_dir / name / "_redirects.yaml") - shutil.copy(tmp_dir / name / "_toc.yaml", output_dir / name / "_toc.yaml") - - splice("s2t", s2t_out) - splice("expression_impl", expr_impl_out) - - toc_path = output_dir / "_toc.yaml" - toc_text = yaml.dump({ - "toc": [{ - "include": "/api_docs/python/s2t/_toc.yaml" - }, { - "break": True - }, { - "include": "/api_docs/python/expression_impl/_toc.yaml" - }] - }) - toc_path.write_text(toc_text) + """Generates the api docs for struct2tensor and expression_impl. + + We need to splice the generated docs together, and create a new table of + contents. + + Args: + output_dir: A pathlib Path that is the top level dir where docs are + generated. + + Returns: + None + """ + s2t_out = pathlib.Path(tempfile.mkdtemp()) + doc_generator = generate_lib.DocGenerator( + root_title="Struct2Tensor", + py_modules=[("s2t", s2t)], + code_url_prefix=FLAGS.code_url_prefix, + search_hints=FLAGS.search_hints, + # explicit_package_contents_filter ensures that only modules imported + # directly from s2t/__init__.py are documented in the location that + # defines them, instead of every location that imports them. + callbacks=[ + public_api.explicit_package_contents_filter, + _filter_module_attributes, + ], + ) + + doc_generator.build(s2t_out) + + expr_impl_out = pathlib.Path(tempfile.mkdtemp()) + doc_generator = generate_lib.DocGenerator( + root_title="Struct2Tensor-expression_impl", + py_modules=[("expression_impl", expression_impl)], + code_url_prefix=FLAGS.code_url_prefix + "/expression_impl", + search_hints=FLAGS.search_hints, + # explicit_package_contents_filter ensures that only modules imported + # directly from s2t/expression_impl/__init__.py are documented in the + # location that defines them, instead of every location that imports them. 
+        callbacks=[
+            public_api.explicit_package_contents_filter,
+            _filter_module_attributes,
+        ],
+    )
+    doc_generator.build(expr_impl_out)
+
+    def splice(name, tmp_dir):
+        shutil.rmtree(output_dir / name, ignore_errors=True)
+        shutil.copytree(tmp_dir / name, output_dir / name)
+        shutil.copy(tmp_dir / f"{name}.md", output_dir / f"{name}.md")
+        with contextlib.suppress(FileNotFoundError):
+            shutil.copy(
+                tmp_dir / name / "_redirects.yaml",
+                output_dir / name / "_redirects.yaml",
+            )
+        shutil.copy(tmp_dir / name / "_toc.yaml", output_dir / name / "_toc.yaml")
+
+    splice("s2t", s2t_out)
+    splice("expression_impl", expr_impl_out)
+
+    toc_path = output_dir / "_toc.yaml"
+    toc_text = yaml.dump(
+        {
+            "toc": [
+                {"include": "/api_docs/python/s2t/_toc.yaml"},
+                {"break": True},
+                {"include": "/api_docs/python/expression_impl/_toc.yaml"},
+            ]
+        }
+    )
+    toc_path.write_text(toc_text)


 def _filter_module_attributes(path, parent, children):
-  """Filter out module attirubtes.
-
-  This removes attributes that are a gen rule attribute.
-  The custom ops that need gen rules will have their docs exposed
-  from the python source file. No need to also get api docs from the generated
-  files.
-
-  Args:
-    path: path to parent attribute.
-    parent: the attribute (module or class).
-    children: the children attributes of parent.
-
-  Returns:
-    The new, filtered, children of the current attribute.
-  """
-  del path
-  skip_module_attributes = {
-      "gen_decode_proto_sparse",
-      "gen_decode_proto_map_op",
-      "gen_equi_join_indices",
-      "gen_parquet_dataset",
-      "gen_run_length_before"
-  }
-  if inspect.ismodule(parent):
-    children = [(name, child)
-                for (name, child) in children
-                if name not in skip_module_attributes]
-  return children
+    """Filter out module attributes.
+
+    This removes attributes that come from a gen rule.
+    The custom ops that need gen rules will have their docs exposed
+    from the python source file. No need to also get api docs from the generated
+    files.
+
+    Args:
+      path: path to parent attribute.
+      parent: the attribute (module or class).
+      children: the children attributes of parent.
+
+    Returns:
+      The new, filtered, children of the current attribute.
+    """
+    del path
+    skip_module_attributes = {
+        "gen_decode_proto_sparse",
+        "gen_decode_proto_map_op",
+        "gen_equi_join_indices",
+        "gen_parquet_dataset",
+        "gen_run_length_before",
+    }
+    if inspect.ismodule(parent):
+        children = [
+            (name, child)
+            for (name, child) in children
+            if name not in skip_module_attributes
+        ]
+    return children


 def main(unused_argv):
-  del unused_argv
-  build_docs(output_dir=pathlib.Path(FLAGS.output_dir))
+    del unused_argv
+    build_docs(output_dir=pathlib.Path(FLAGS.output_dir))


 if __name__ == "__main__":
-  app.run(main)
+    app.run(main)
diff --git a/struct2tensor/version.py b/struct2tensor/version.py
index b61835f..6e1deb3 100644
--- a/struct2tensor/version.py
+++ b/struct2tensor/version.py
@@ -15,4 +15,4 @@
 """Contains the version string of struct2tensor."""

 # Note that setup.py uses this version.
-__version__ = '0.49.0.dev'
+__version__ = "0.49.0.dev"
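The _expand_dims expectations reformatted near the top of this patch can be sanity-checked against plain tf.expand_dims, since (going by the test data alone, not the implementation) expanding a structured tensor inserts the same length-1 dimension into each field. A minimal sketch in plain TensorFlow:

import tensorflow as tf

# Mirrors the "one_dense_field_axis_1" case: a [2, 2] field expanded at
# axis 1 becomes [[[1, 2]], [[3, 4]]], which is the test's expected value.
r = tf.constant([[1, 2], [3, 4]])
print(tf.expand_dims(r, axis=1).numpy().tolist())  # [[[1, 2]], [[3, 4]]]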
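The expression_test_util.py helpers reformatted above compose as follows. This is a hypothetical usage sketch, with the import path assumed from the file's location in the tree:

import tensorflow as tf

from struct2tensor.test import expression_test_util

# A mock leaf with no source expressions: calculate_value_slowly recurses
# over an empty source list and returns the canned calculate_output
# (an empty repeated int64 LeafNodeTensor built by get_mock_leaf).
leaf = expression_test_util.get_mock_leaf(
    is_repeated=True, my_type=tf.int64, name="my_leaf"
)
node = expression_test_util.calculate_value_slowly(leaf)
assert node is leaf.calculate_output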
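Every helper in the prensor_test_util.py hunks encodes structure with a parent_index vector: values[i] belongs to the record parent_index[i]. The sketch below (plain TensorFlow only; struct2tensor's get_ragged_tensors() performs the real conversion) decodes the "foorepeated" field of create_simple_prensor:

import tensorflow as tf

# values [9, 8, 7, 6] with parent_index [0, 1, 1, 2] over a root of size 3:
# record 0 holds [9], record 1 holds [8, 7], and record 2 holds [6].
foorepeated = tf.RaggedTensor.from_value_rowids(
    values=tf.constant([9, 8, 7, 6], dtype=tf.int64),
    value_rowids=tf.constant([0, 1, 1, 2], dtype=tf.int64),
    nrows=3,
)
print(foorepeated.to_list())  # [[9], [8, 7], [6]], matching the docstring.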