More robust first elem check in encode/cast example (#3402)
* More robust first elem check in encode/cast example

* Type hints and better docstring
mariosasko committed Dec 8, 2021
1 parent 49bb250 commit 18e0adf
Showing 2 changed files with 45 additions and 6 deletions.
34 changes: 28 additions & 6 deletions src/datasets/features/features.py
@@ -152,7 +152,7 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool) -> Tuple[Any, boo
Cast pytorch/tensorflow/pandas objects to python numpy array/lists.
It works recursively.
To avoid iterating over possibly long lists, it first checks if the first element that is not None has to be casted.
To avoid iterating over possibly long lists, it first checks (recursively) if the first element that is not None or empty (if it is a sequence) has to be casted.
If the first element needs to be casted, then all the elements of the list will be casted, otherwise they'll stay the same.
This trick allows to cast objects that contain tokenizers outputs without iterating over every single token for example.
@@ -221,7 +221,7 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool) -> Tuple[Any, boo
elif isinstance(obj, (list, tuple)):
if len(obj) > 0:
for first_elmt in obj:
if first_elmt is not None:
if _check_non_null_non_empty_recursive(first_elmt):
break
casted_first_elmt, has_changed_first_elmt = _cast_to_python_objects(
first_elmt, only_1d_for_numpy=only_1d_for_numpy
@@ -244,7 +244,7 @@ def cast_to_python_objects(obj: Any, only_1d_for_numpy=False) -> Any:
Cast numpy/pytorch/tensorflow/pandas objects to python lists.
It works recursively.
To avoid iterating over possibly long lists, it first checks if the first element that is not None has to be casted.
To avoid iterating over possibly long lists, it first checks (recursively) if the first element that is not None or empty (if it is a sequence) has to be casted.
If the first element needs to be casted, then all the elements of the list will be casted, otherwise they'll stay the same.
This trick allows to cast objects that contain tokenizers outputs without iterating over every single token for example.
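The updated check is easiest to see with a nested input whose first inner list is empty. The sketch below is not part of the commit; it assumes the `datasets` code at this revision, uses the public `cast_to_python_objects` whose docstring is shown above, and the outputs in comments are expected results rather than captured ones.

```python
import numpy as np

from datasets.features.features import cast_to_python_objects

# The first inner list is empty. The old check ("first element that is not
# None") stopped at [] and saw nothing to cast, so the numpy array further
# down could be left untouched. The recursive check skips the empty list,
# reaches np.array([1, 2]), and triggers casting of the whole column.
batch = {"x": [[], [np.array([1, 2])]]}
print(cast_to_python_objects(batch))
# expected: {'x': [[], [[1, 2]]]}
```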
@@ -774,6 +774,28 @@ class Sequence:
]


def _check_non_null_non_empty_recursive(obj, schema: Optional[FeatureType] = None) -> bool:
"""
Check if the object is not None.
If the object is a list or a tuple, recursively check the first element of the sequence and stop if at any point the first element is not a sequence or is an empty sequence.
"""
if obj is None:
return False
elif isinstance(obj, (list, tuple)) and (schema is None or isinstance(schema, (list, tuple, Sequence))):
if len(obj) > 0:
if schema is None:
pass
elif isinstance(schema, (list, tuple)):
schema = schema[0]
else:
schema = schema.feature
return _check_non_null_non_empty_recursive(obj[0], schema)
else:
return False
else:
return True


def get_nested_type(schema: FeatureType) -> pa.DataType:
"""
get_nested_type() converts a datasets.FeatureType into a pyarrow.DataType, and acts as the inverse of
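To make the helper's contract concrete, here is a small usage sketch. It is not part of the commit; the return values are what the definition above implies, and importing the private helper is only for illustration.

```python
from datasets import ClassLabel, Sequence
from datasets.features.features import _check_non_null_non_empty_recursive

_check_non_null_non_empty_recursive(None)          # False: None
_check_non_null_non_empty_recursive([])            # False: empty sequence
_check_non_null_non_empty_recursive([[], ["a"]])   # False: first element is an empty sequence
_check_non_null_non_empty_recursive("a")           # True: non-None, non-sequence value
_check_non_null_non_empty_recursive([["a"]])       # True: recursion reaches "a"

# With a schema, the recursion only descends while the schema itself is a
# list, tuple or Sequence; here it unwraps the Sequence once and reaches "a".
_check_non_null_non_empty_recursive(["a"], Sequence(ClassLabel(names=["a", "b"])))  # True
```

In the scan loops above and below, a False result simply moves the scan on to the next element, so a None or empty leading element no longer decides whether the whole list gets cast or encoded.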
@@ -810,7 +832,7 @@ def encode_nested_example(schema, obj):
"""Encode a nested example.
This is used since some features (in particular ClassLabel) have some logic during encoding.
To avoid iterating over possibly long lists, it first checks if the first element that is not None has to be encoded.
To avoid iterating over possibly long lists, it first checks (recursively) if the first element that is not None or empty (if it is a sequence) has to be encoded.
If the first element needs to be encoded, then all the elements of the list will be encoded, otherwise they'll stay the same.
"""
# Nested structures: we allow dict, list/tuples, sequences
Expand All @@ -825,7 +847,7 @@ def encode_nested_example(schema, obj):
else:
if len(obj) > 0:
for first_elmt in obj:
if first_elmt is not None:
if _check_non_null_non_empty_recursive(first_elmt, sub_schema):
break
if encode_nested_example(sub_schema, first_elmt) != first_elmt:
return [encode_nested_example(sub_schema, o) for o in obj]
@@ -853,7 +875,7 @@ def encode_nested_example(schema, obj):
else:
if len(obj) > 0:
for first_elmt in obj:
if first_elmt is not None:
if _check_non_null_non_empty_recursive(first_elmt, schema.feature):
break
# be careful when comparing tensors here
if not isinstance(first_elmt, list) or encode_nested_example(schema.feature, first_elmt) != first_elmt:
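Taken together, these hunks mean that a leading empty sequence no longer short-circuits the first-element scan in encode_nested_example. A hedged sketch of the effect, calling the private function directly; the expected values mirror the new test below.

```python
from datasets import ClassLabel, Sequence
from datasets.features.features import encode_nested_example

schema = Sequence(Sequence(ClassLabel(names=["a", "b"])))

# The scan skips the empty inner list, inspects ["b"], sees that it needs
# encoding, and therefore encodes every element of the outer list.
print(encode_nested_example(schema, [[], ["b"]]))     # expected: [[], [1]]
print(encode_nested_example(schema, [["a"], ["b"]]))  # expected: [[0], [1]]
```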
17 changes: 17 additions & 0 deletions tests/features/test_features.py
@@ -244,6 +244,23 @@ def test_encode_nested_example_sequence_with_none():
assert result is None


def test_encode_batch_with_example_with_empty_first_elem():
features = Features(
{
"x": Sequence(Sequence(ClassLabel(names=["a", "b"]))),
}
)
encoded_batch = features.encode_batch(
{
"x": [
[["a"], ["b"]],
[[], ["b"]],
]
}
)
assert encoded_batch == {"x": [[[0], [1]], [[], [1]]]}


def iternumpy(key1, value1, value2):
if value1.dtype != value2.dtype: # check only for dtype
raise AssertionError(
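The same behaviour is also reachable through the public Features API; a short sketch, not part of the commit, with the expected output mirroring the second row of the batch in the test above.

```python
from datasets import ClassLabel, Features, Sequence

features = Features({"x": Sequence(Sequence(ClassLabel(names=["a", "b"])))})
print(features.encode_example({"x": [[], ["b"]]}))
# expected: {'x': [[], [1]]}
```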

1 comment on commit 18e0adf

@github-actions

PyArrow==3.0.0


Benchmark: benchmark_array_xd.json

| metric | new / old (diff) |
|---|---|
| read_batch_formatted_as_numpy after write_array2d | 0.064456 / 0.011353 (0.053103) |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.003706 / 0.011008 (-0.007302) |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.028371 / 0.038508 (-0.010137) |
| read_batch_unformated after write_array2d | 0.032020 / 0.023109 (0.008911) |
| read_batch_unformated after write_flattened_sequence | 0.258883 / 0.275898 (-0.017015) |
| read_batch_unformated after write_nested_sequence | 0.302997 / 0.323480 (-0.020483) |
| read_col_formatted_as_numpy after write_array2d | 0.081466 / 0.007986 (0.073480) |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.004561 / 0.004328 (0.000233) |
| read_col_formatted_as_numpy after write_nested_sequence | 0.008247 / 0.004250 (0.003996) |
| read_col_unformated after write_array2d | 0.035508 / 0.037052 (-0.001544) |
| read_col_unformated after write_flattened_sequence | 0.257490 / 0.258489 (-0.000999) |
| read_col_unformated after write_nested_sequence | 0.298945 / 0.293841 (0.005104) |
| read_formatted_as_numpy after write_array2d | 0.076406 / 0.128546 (-0.052141) |
| read_formatted_as_numpy after write_flattened_sequence | 0.007936 / 0.075646 (-0.067710) |
| read_formatted_as_numpy after write_nested_sequence | 0.228691 / 0.419271 (-0.190581) |
| read_unformated after write_array2d | 0.042359 / 0.043533 (-0.001174) |
| read_unformated after write_flattened_sequence | 0.258026 / 0.255139 (0.002887) |
| read_unformated after write_nested_sequence | 0.281574 / 0.283200 (-0.001626) |
| write_array2d | 0.074387 / 0.141683 (-0.067296) |
| write_flattened_sequence | 1.506933 / 1.452155 (0.054778) |
| write_nested_sequence | 1.563466 / 1.492716 (0.070749) |

Benchmark: benchmark_getitem_100B.json

| metric | new / old (diff) |
|---|---|
| get_batch_of_1024_random_rows | 0.226802 / 0.018006 (0.208795) |
| get_batch_of_1024_rows | 0.443700 / 0.000490 (0.443210) |
| get_first_row | 0.002440 / 0.000200 (0.002240) |
| get_last_row | 0.000076 / 0.000054 (0.000021) |

Benchmark: benchmark_indices_mapping.json

| metric | new / old (diff) |
|---|---|
| select | 0.033079 / 0.037411 (-0.004332) |
| shard | 0.020973 / 0.014526 (0.006448) |
| shuffle | 0.028437 / 0.176557 (-0.148119) |
| sort | 0.175394 / 0.737135 (-0.561741) |
| train_test_split | 0.029998 / 0.296338 (-0.266341) |

Benchmark: benchmark_iterating.json

| metric | new / old (diff) |
|---|---|
| read 5000 | 0.387787 / 0.215209 (0.172578) |
| read 50000 | 3.871218 / 2.077655 (1.793563) |
| read_batch 50000 10 | 1.701658 / 1.504120 (0.197538) |
| read_batch 50000 100 | 1.540654 / 1.541195 (-0.000540) |
| read_batch 50000 1000 | 1.619225 / 1.468490 (0.150735) |
| read_formatted numpy 5000 | 0.373044 / 4.584777 (-4.211732) |
| read_formatted pandas 5000 | 4.238523 / 3.745712 (0.492811) |
| read_formatted tensorflow 5000 | 3.350717 / 5.269862 (-1.919144) |
| read_formatted torch 5000 | 0.789873 / 4.565676 (-3.775804) |
| read_formatted_batch numpy 5000 10 | 0.045448 / 0.424275 (-0.378827) |
| read_formatted_batch numpy 5000 1000 | 0.010192 / 0.007607 (0.002585) |
| shuffled read 5000 | 0.488416 / 0.226044 (0.262371) |
| shuffled read 50000 | 4.847399 / 2.268929 (2.578470) |
| shuffled read_batch 50000 10 | 2.099292 / 55.444624 (-53.345333) |
| shuffled read_batch 50000 100 | 1.737064 / 6.876477 (-5.139413) |
| shuffled read_batch 50000 1000 | 1.867369 / 2.142072 (-0.274704) |
| shuffled read_formatted numpy 5000 | 0.482872 / 4.805227 (-4.322356) |
| shuffled read_formatted_batch numpy 5000 10 | 0.104030 / 6.500664 (-6.396634) |
| shuffled read_formatted_batch numpy 5000 1000 | 0.052769 / 0.075469 (-0.022700) |

Benchmark: benchmark_map_filter.json

| metric | new / old (diff) |
|---|---|
| filter | 1.350020 / 1.841788 (-0.491768) |
| map fast-tokenizer batched | 11.507824 / 8.074308 (3.433515) |
| map identity | 23.641922 / 10.191392 (13.450530) |
| map identity batched | 0.624974 / 0.680424 (-0.055449) |
| map no-op batched | 0.450480 / 0.534201 (-0.083721) |
| map no-op batched numpy | 0.330439 / 0.579283 (-0.248845) |
| map no-op batched pandas | 0.462275 / 0.434364 (0.027911) |
| map no-op batched pytorch | 0.226467 / 0.540337 (-0.313870) |
| map no-op batched tensorflow | 0.235522 / 1.386936 (-1.151414) |
PyArrow==latest

Benchmark: benchmark_array_xd.json

| metric | new / old (diff) |
|---|---|
| read_batch_formatted_as_numpy after write_array2d | 0.062808 / 0.011353 (0.051455) |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.003815 / 0.011008 (-0.007194) |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.026578 / 0.038508 (-0.011930) |
| read_batch_unformated after write_array2d | 0.031201 / 0.023109 (0.008091) |
| read_batch_unformated after write_flattened_sequence | 0.258658 / 0.275898 (-0.017240) |
| read_batch_unformated after write_nested_sequence | 0.297484 / 0.323480 (-0.025996) |
| read_col_formatted_as_numpy after write_array2d | 0.079804 / 0.007986 (0.071819) |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.004556 / 0.004328 (0.000227) |
| read_col_formatted_as_numpy after write_nested_sequence | 0.006631 / 0.004250 (0.002380) |
| read_col_unformated after write_array2d | 0.042329 / 0.037052 (0.005276) |
| read_col_unformated after write_flattened_sequence | 0.254725 / 0.258489 (-0.003764) |
| read_col_unformated after write_nested_sequence | 0.296405 / 0.293841 (0.002564) |
| read_formatted_as_numpy after write_array2d | 0.075850 / 0.128546 (-0.052696) |
| read_formatted_as_numpy after write_flattened_sequence | 0.007973 / 0.075646 (-0.067673) |
| read_formatted_as_numpy after write_nested_sequence | 0.224582 / 0.419271 (-0.194690) |
| read_unformated after write_array2d | 0.040891 / 0.043533 (-0.002642) |
| read_unformated after write_flattened_sequence | 0.260747 / 0.255139 (0.005608) |
| read_unformated after write_nested_sequence | 0.283252 / 0.283200 (0.000053) |
| write_array2d | 0.075767 / 0.141683 (-0.065916) |
| write_flattened_sequence | 1.556745 / 1.452155 (0.104590) |
| write_nested_sequence | 1.579824 / 1.492716 (0.087107) |

Benchmark: benchmark_getitem_100B.json

| metric | new / old (diff) |
|---|---|
| get_batch_of_1024_random_rows | 0.255038 / 0.018006 (0.237032) |
| get_batch_of_1024_rows | 0.447096 / 0.000490 (0.446606) |
| get_first_row | 0.006810 / 0.000200 (0.006610) |
| get_last_row | 0.000280 / 0.000054 (0.000225) |

Benchmark: benchmark_indices_mapping.json

| metric | new / old (diff) |
|---|---|
| select | 0.030377 / 0.037411 (-0.007034) |
| shard | 0.020141 / 0.014526 (0.005616) |
| shuffle | 0.026359 / 0.176557 (-0.150198) |
| sort | 0.172987 / 0.737135 (-0.564148) |
| train_test_split | 0.026966 / 0.296338 (-0.269372) |

Benchmark: benchmark_iterating.json

| metric | new / old (diff) |
|---|---|
| read 5000 | 0.384263 / 0.215209 (0.169054) |
| read 50000 | 3.826681 / 2.077655 (1.749027) |
| read_batch 50000 10 | 1.691373 / 1.504120 (0.187253) |
| read_batch 50000 100 | 1.559712 / 1.541195 (0.018517) |
| read_batch 50000 1000 | 1.672379 / 1.468490 (0.203889) |
| read_formatted numpy 5000 | 0.369508 / 4.584777 (-4.215269) |
| read_formatted pandas 5000 | 4.260061 / 3.745712 (0.514349) |
| read_formatted tensorflow 5000 | 1.939068 / 5.269862 (-3.330793) |
| read_formatted torch 5000 | 0.793298 / 4.565676 (-3.772379) |
| read_formatted_batch numpy 5000 10 | 0.044518 / 0.424275 (-0.379757) |
| read_formatted_batch numpy 5000 1000 | 0.010364 / 0.007607 (0.002757) |
| shuffled read 5000 | 0.475362 / 0.226044 (0.249318) |
| shuffled read 50000 | 4.763381 / 2.268929 (2.494453) |
| shuffled read_batch 50000 10 | 2.078293 / 55.444624 (-53.366331) |
| shuffled read_batch 50000 100 | 1.765534 / 6.876477 (-5.110942) |
| shuffled read_batch 50000 1000 | 1.889915 / 2.142072 (-0.252158) |
| shuffled read_formatted numpy 5000 | 0.475719 / 4.805227 (-4.329509) |
| shuffled read_formatted_batch numpy 5000 10 | 0.101620 / 6.500664 (-6.399044) |
| shuffled read_formatted_batch numpy 5000 1000 | 0.052045 / 0.075469 (-0.023424) |

Benchmark: benchmark_map_filter.json

| metric | new / old (diff) |
|---|---|
| filter | 1.506190 / 1.841788 (-0.335598) |
| map fast-tokenizer batched | 11.740585 / 8.074308 (3.666277) |
| map identity | 26.651356 / 10.191392 (16.459964) |
| map identity batched | 0.707990 / 0.680424 (0.027566) |
| map no-op batched | 0.522295 / 0.534201 (-0.011906) |
| map no-op batched numpy | 0.369513 / 0.579283 (-0.209770) |
| map no-op batched pandas | 0.482216 / 0.434364 (0.047852) |
| map no-op batched pytorch | 0.230486 / 0.540337 (-0.309851) |
| map no-op batched tensorflow | 0.243948 / 1.386936 (-1.142988) |
