Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More robust first elem check in encode/cast example #3402

Merged
merged 2 commits into from
Dec 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 28 additions & 6 deletions src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool) -> Tuple[Any, boo
Cast pytorch/tensorflow/pandas objects to python numpy array/lists.
It works recursively.

To avoid iterating over possibly long lists, it first checks if the first element that is not None has to be casted.
To avoid iterating over possibly long lists, it first checks (recursively) if the first element that is not None or empty (if it is a sequence) has to be casted.
If the first element needs to be casted, then all the elements of the list will be casted, otherwise they'll stay the same.
This trick allows to cast objects that contain tokenizers outputs without iterating over every single token for example.

Expand Down Expand Up @@ -221,7 +221,7 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool) -> Tuple[Any, boo
elif isinstance(obj, (list, tuple)):
if len(obj) > 0:
for first_elmt in obj:
if first_elmt is not None:
if _check_non_null_non_empty_recursive(first_elmt):
break
casted_first_elmt, has_changed_first_elmt = _cast_to_python_objects(
first_elmt, only_1d_for_numpy=only_1d_for_numpy
Expand All @@ -244,7 +244,7 @@ def cast_to_python_objects(obj: Any, only_1d_for_numpy=False) -> Any:
Cast numpy/pytorch/tensorflow/pandas objects to python lists.
It works recursively.

To avoid iterating over possibly long lists, it first checks if the first element that is not None has to be casted.
To avoid iterating over possibly long lists, it first checks (recursively) if the first element that is not None or empty (if it is a sequence) has to be casted.
If the first element needs to be casted, then all the elements of the list will be casted, otherwise they'll stay the same.
This trick allows to cast objects that contain tokenizers outputs without iterating over every single token for example.

Expand Down Expand Up @@ -774,6 +774,28 @@ class Sequence:
]


def _check_non_null_non_empty_recursive(obj, schema: Optional[FeatureType] = None) -> bool:
"""
Check if the object is not None.
If the object is a list or a tuple, recursively check the first element of the sequence and stop if at any point the first element is not a sequence or is an empty sequence.
"""
if obj is None:
return False
elif isinstance(obj, (list, tuple)) and (schema is None or isinstance(schema, (list, tuple, Sequence))):
if len(obj) > 0:
if schema is None:
pass
elif isinstance(schema, (list, tuple)):
schema = schema[0]
else:
schema = schema.feature
return _check_non_null_non_empty_recursive(obj[0], schema)
else:
return False
else:
return True


def get_nested_type(schema: FeatureType) -> pa.DataType:
"""
get_nested_type() converts a datasets.FeatureType into a pyarrow.DataType, and acts as the inverse of
Expand Down Expand Up @@ -810,7 +832,7 @@ def encode_nested_example(schema, obj):
"""Encode a nested example.
This is used since some features (in particular ClassLabel) have some logic during encoding.

To avoid iterating over possibly long lists, it first checks if the first element that is not None has to be encoded.
To avoid iterating over possibly long lists, it first checks (recursively) if the first element that is not None or empty (if it is a sequence) has to be encoded.
If the first element needs to be encoded, then all the elements of the list will be encoded, otherwise they'll stay the same.
"""
# Nested structures: we allow dict, list/tuples, sequences
Expand All @@ -825,7 +847,7 @@ def encode_nested_example(schema, obj):
else:
if len(obj) > 0:
for first_elmt in obj:
if first_elmt is not None:
if _check_non_null_non_empty_recursive(first_elmt, sub_schema):
break
if encode_nested_example(sub_schema, first_elmt) != first_elmt:
return [encode_nested_example(sub_schema, o) for o in obj]
Expand Down Expand Up @@ -853,7 +875,7 @@ def encode_nested_example(schema, obj):
else:
if len(obj) > 0:
for first_elmt in obj:
if first_elmt is not None:
if _check_non_null_non_empty_recursive(first_elmt, schema.feature):
break
# be careful when comparing tensors here
if not isinstance(first_elmt, list) or encode_nested_example(schema.feature, first_elmt) != first_elmt:
Expand Down
17 changes: 17 additions & 0 deletions tests/features/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,23 @@ def test_encode_nested_example_sequence_with_none():
assert result is None


def test_encode_batch_with_example_with_empty_first_elem():
features = Features(
{
"x": Sequence(Sequence(ClassLabel(names=["a", "b"]))),
}
)
encoded_batch = features.encode_batch(
{
"x": [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be properly tested, does the first element has to be an empty list ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because encode_batch calls encode_nested_example which goes from the beginning of the list and tries to find the first element that is "good enough" to perform additional checks, and we consider an element "good enough" if it's not None or if it's not an empty sequence (possible nested). Previously, we would stop on the first element that is not None, but this could lead to issues such as the one this PR fixes.

[["a"], ["b"]],
[[], ["b"]],
]
}
)
assert encoded_batch == {"x": [[[0], [1]], [[], [1]]]}


def iternumpy(key1, value1, value2):
if value1.dtype != value2.dtype: # check only for dtype
raise AssertionError(
Expand Down